import {
	MeshStandardMaterial,
	RepeatWrapping,
	type Color,
	type Material,
	type Texture,
} from "three";
import { ExtendedMaterial } from "three-extended-material";
// Name given to the shader extension that createBurnableMaterial registers
// with ExtendedMaterial.
const tempCachedName = "burnablematerial";

/**
 * Options accepted by (and the declared return type of) `createBurnableMaterial`.
 *
 * The burn-specific fields are combined with `Partial<Material>`, so regular
 * three.js material properties can be supplied alongside them.
 */
export type TBurnableMaterial = {
	/**
	 * Noise texture sampled by the fragment shader through triplanar mapping.
	 * The factory sets its `wrapS`/`wrapT` to `RepeatWrapping`.
	 */
	noiseTexture: Texture;
	/**
	 * Burn amount; registered as the `progress` uniform (0 when omitted).
	 * NOTE(review): the vertex shader reads `progress` as a per-vertex
	 * attribute — confirm how this option is expected to reach it.
	 */
	progress?: number;
	/** Value of the `time` uniform used to pan the noise texture (0 when omitted). */
	time?: number;
	/**
	 * Speed multiplier for the noise panning; corresponds to the `timeScale`
	 * uniform, which the shader multiplies with `time`. The uniform is 7 when
	 * this is omitted.
	 */
	timeScale?: number;
	/** Scale applied to world-space coordinates before sampling the noise (0.55 when omitted). */
	scale?: number;
	/**
	 * Not read by the factory itself; forwarded to `ExtendedMaterial` together
	 * with the rest of the options.
	 */
	color?: Color;
} & Partial<Material>;

/**
 * Builds a `MeshStandardMaterial`-based `ExtendedMaterial` whose shaders are
 * patched with a noise-driven "burning" effect.
 *
 * Side effect: when `material.noiseTexture` is provided, its `wrapS` and
 * `wrapT` are set to `RepeatWrapping` on the caller's texture object, because
 * the shader samples it with unbounded, time-panned coordinates.
 *
 * Shader patches:
 * - Vertex: declares the varyings, reads a per-vertex `progress` attribute and
 *   writes world-space position/normal plus `vProgress` at the start of `main`.
 * - Fragment: adds a triplanar sampling helper and injects the burn/coal
 *   colour blending right after `#include <transmission_fragment>`, modifying
 *   `totalDiffuse`, `totalEmissiveRadiance` and `totalSpecular`.
 *
 * NOTE(review): a `progress` uniform is registered, but the shaders read
 * `progress` as a vertex attribute — confirm which source is intended.
 *
 * @param material Burn options plus any regular material properties. The
 *   whole object is also forwarded to `ExtendedMaterial` as its third argument.
 * @returns The constructed extended material.
 */
function createBurnableMaterial(
	material: Partial<TBurnableMaterial>,
): TBurnableMaterial {
	const { noiseTexture } = material;
	if (noiseTexture) {
		noiseTexture.wrapS = noiseTexture.wrapT = RepeatWrapping;
	}

	const burnMat = new ExtendedMaterial(
		MeshStandardMaterial,
		[
			{
				name: tempCachedName,
				// `??` rather than `||` so an explicit 0 is respected; every
				// uniform, including timeScale, is initialised from the options.
				uniforms: {
					noiseTexture,
					progress: material.progress ?? 0,
					time: material.time ?? 0,
					timeScale: material.timeScale ?? 7,
					scale: material.scale ?? 0.55,
				},
				// Prepends declarations and seeds the varyings at the top of main().
				vertexShader: (shader: string) => {
					return `
      #define USE_MAP
      #define USE_ENVMAP
      
      uniform float time;
      uniform float timeScale;
      varying vec3 vWorldPosition;
      varying vec3 vWorldNormal;
      attribute float progress;
      varying float vProgress;

      ${shader.replace(
				"void main() {",
				`
        void main() {
          vWorldPosition = (modelMatrix * vec4(position, 1.0)).xyz;
          vWorldNormal = normalize(mat3(modelMatrix) * normal);
          vProgress = progress;
        `,
			)}
      `;
				},
				// Float literals in the injected GLSL carry no `f` suffix so the
				// code is also accepted by GLSL ES 1.00 compilers.
				fragmentShader: (shader: string) => {
					return `
        #define USE_MAP

        varying vec3 vWorldPosition;
        varying vec3 vWorldNormal;
        varying float vProgress;
        
        uniform sampler2D noiseTexture;
        uniform float time;
        uniform float timeScale;
        uniform float scale;
        
        // Function to calculate proper triplanar mapping
        vec4 triplanarMapping(sampler2D tex, vec3 blendWeights, vec3 position, float time, float timeScale, float scale) {
            // Normalize blend weights
            blendWeights = normalize(max(blendWeights, 0.00001));
            
            // Enhance blend weights for sharper transitions
            // blendWeights = pow(blendWeights, vec3(4.0));
            // blendWeights = blendWeights / (blendWeights.x + blendWeights.y + blendWeights.z);
            
            // Calculate texture coordinates with improved time-based panning
            float scaledTime = time * timeScale;
            vec2 texCoordsX = abs(position.zy) * scale + vec2(scaledTime, -scaledTime) + 0.1;
            vec2 texCoordsY = abs(position.xz) * scale + vec2(-scaledTime, -scaledTime) + 0.1;
            vec2 texCoordsZ = abs(position.xy) * scale + vec2(-scaledTime, scaledTime) + 0.1;
            
            // Sample textures for each axis
            vec4 colorX = texture2D(tex, texCoordsX);
            vec4 colorY = texture2D(tex, texCoordsY);
            vec4 colorZ = texture2D(tex, texCoordsZ);
            
            // Combine textures based on blending weights
            return colorX * blendWeights.x + colorY * blendWeights.y + colorZ * blendWeights.z;
        }

        ${shader.replace(
					"#include <transmission_fragment>",
					`#include <transmission_fragment>
            // Calculate blending weights based on the absolute world normal
            vec3 blendWeights = abs(normalize(vWorldNormal));

            float burnProgress = vProgress * 2.0;
            float coalProgress = abs(vProgress * 2.0 - 1.0) ;
            float invBurnProgress = 1.0 - burnProgress;
            float invCoalProgress = 1.0 - coalProgress;

            // Calculate the triplanar mapping color with proper blending and scaling
            vec4 emberTex = triplanarMapping(noiseTexture, blendWeights, vWorldPosition, time, timeScale, scale);
            vec4 coalTex = vec4(triplanarMapping(noiseTexture, blendWeights, vWorldPosition, 1.0, 1.0, scale * 10.0).r) * 0.15;

            // Get burn color and modify the fragment output accordingly
            vec3 burnColor = vec3(emberTex.r * emberTex.r * 4.0, 0.0, 0.0) * burnProgress;

            // fade out the original texture as progress reaches 1
            totalDiffuse.rgb *= invBurnProgress;
            vec3 burn = ((1.0 - burnProgress) * totalDiffuse.rgb) * (1.0 - burnColor.r) + burnColor.rgb;
            totalDiffuse.rgb = mix(totalDiffuse.rgb, vec3(0.0), vProgress);
            vec3 coal = pow(mix(totalDiffuse.rgb, coalTex.rgb, coalProgress), vec3(2.0));
            burn = mix(burn, coal, coalProgress);
            totalDiffuse.rgb = mix(totalDiffuse.rgb, burn, burnProgress);

            vec3 emissive = mix(mix(vec3(0.0), burnColor.rgb, burnProgress), vec3(0.0), coalProgress);

            totalEmissiveRadiance += emissive;
            totalSpecular += emissive;
          `,
				)}
      `;
				},
			},
		],
		material,
	);

	return burnMat;
}

export { createBurnableMaterial };
