import {Matrix4, UniformsUtils, UniformsLib} from 'three'

/**
 * Shader definition (uniforms + GLSL sources) for a three.js ShaderMaterial
 * that does per-fragment Blinn-Phong shading of the scene's point lights plus
 * an ambient term, textured by `colorMap`, with an optional vertex
 * "distortion" effect and two extra matrices (`preRot`, `postRot`) spliced
 * into the model-view transform.
 *
 * The fragment shader reads the three.js light uniforms merged in from
 * `UniformsLib.lights` (`pointLights`, `ambientLightColor`); presumably the
 * material is created with `lights: true` so three.js fills them — TODO
 * confirm at the call site.
 *
 * NOTE(review): `new ShaderMaterial({ uniforms })` keeps the given uniforms
 * object as-is, so several materials built directly from this export would
 * share every uniform value; use `UniformsUtils.clone(shader.uniforms)` per
 * material if they need independent settings.
 */
export default {
    uniforms: UniformsUtils.merge([
        // three.js light uniforms (pointLights, ambientLightColor, ...).
        UniformsLib.lights,
        {
            // Base colour texture; only .rgb is used.
            colorMap: { value: null },
            // NOTE(review): declared in the fragment shader but never sampled.
            normalMap: { value: null },
            // Red channel lowers the specular exponent: 0 -> shininess, 1 -> ~0.
            roughnessMap: { value: null },
            // Red channel becomes the output alpha.
            alphaMap: { value: null },
            // Specular exponent for a fully smooth (roughness 0) surface.
            shininess: { value: 5.0 },
            // Weights of the specular, diffuse and ambient terms.
            specularFact: { value: 0.3 },
            diffuseFact: { value: 1.0 },
            ambientFact: { value: 1.0 },
            // 0 -> colorMap ignored (white base), 1 -> colorMap applied fully.
            textureInfluence: { value: 1.0 },
            // 0 -> unlit (white base light, no specular), 1 -> fully lit.
            lightingInfluence: { value: 1.0 },
            // 0 -> original vertex positions, 1 -> fully distorted positions.
            distortion: { value: 0.0 },
            // Extra transforms: preRot is applied between modelMatrix and
            // viewMatrix, postRot between viewMatrix and projectionMatrix.
            // Separate instances so the two uniforms can never alias.
            postRot: { value: new Matrix4() },
            preRot: { value: new Matrix4() },
        },
    ]),

    // Distortion: each vertex is projected onto a unit sphere around `centr`,
    // multiplied component-wise by its offset from the centre of its
    // 0.5-unit grid cell (each component in [-0.5, 0.5)), then blended with
    // the original position by `distortion`.
    //
    // NOTE(review): normalMatrix only covers modelMatrix/viewMatrix — neither
    // preRot/postRot nor the distortion is applied to the normal, so the
    // lighting is only consistent with vPos while preRot/postRot are
    // identity. Confirm this is intended.
    vertexShader: /* glsl */ `
        varying vec2 vUV;
        varying vec3 vPos;
        varying vec3 vNormal;
        uniform mat4 postRot;
        uniform mat4 preRot;

        uniform float distortion;

        void main(void) {
            vec3 centr = vec3(0.0, 0.55, 0.4);
            vec3 distortedPos = centr + normalize(position - centr);
            vec3 tempPos = position * 2.0;
            vec3 cell = floor(tempPos);
            vec3 frac = tempPos - cell;
            distortedPos *= frac - vec3(0.5);
            vec3 pos = mix(position, distortedPos, distortion);

            // Compute the (rotated) view-space position once and reuse it.
            vec4 viewPos = postRot * viewMatrix * preRot * modelMatrix * vec4(pos, 1.0);
            gl_Position = projectionMatrix * viewPos;
            vUV = uv;
            vPos = viewPos.xyz;
            vNormal = normalMatrix * normal;
        }
    `,

    fragmentShader: /* glsl */ `
        precision highp float;
        precision highp sampler2D;

        uniform sampler2D colorMap;
        uniform sampler2D normalMap;
        uniform sampler2D roughnessMap;
        uniform sampler2D alphaMap;

        uniform float shininess;
        uniform float specularFact;
        uniform float diffuseFact;
        uniform float ambientFact;

        uniform float textureInfluence;
        uniform float lightingInfluence;

        varying vec2 vUV;
        varying vec3 vPos;
        varying vec3 vNormal;

        // A zero-length uniform array does not compile, so the point-light
        // declarations and loop only exist when the scene has point lights.
        #if NUM_POINT_LIGHTS > 0
        struct PointLight {
            vec3 position;
            vec3 color;
        };
        uniform PointLight pointLights[ NUM_POINT_LIGHTS ];
        #endif
        uniform vec3 ambientLightColor;

        void main(void) {
            vec3 col = texture2D(colorMap, vUV).rgb;
            float alpha = texture2D(alphaMap, vUV).r;
            float roughness = texture2D(roughnessMap, vUV).r;

            // Interpolation shortens per-vertex normals; re-normalise here.
            vec3 surfaceNormal = normalize(vNormal);
            // Camera sits at the origin of the space vPos is expressed in.
            vec3 viewDir = normalize(-vPos);

            vec3 lambertCol = vec3(0.0);
            vec3 blinnCol = vec3(0.0);
        #if NUM_POINT_LIGHTS > 0
            // pow(0.0, y) is undefined for y <= 0, which roughness == 1.0
            // would otherwise produce; keep the exponent strictly positive.
            float specExponent = max(shininess * (1.0 - roughness), 1.0e-4);
            for (int i = 0; i < NUM_POINT_LIGHTS; i++) {
                vec3 lightVec = normalize(pointLights[i].position - vPos);
                // Clamp so lights behind the surface do not subtract light.
                float lambertian = max(dot(lightVec, surfaceNormal), 0.0);
                lambertCol += diffuseFact * pointLights[i].color * lambertian;
                if (lambertian > 0.0) {
                    // Fade the highlight in over the first 0.2 of N.L so it
                    // does not pop at the terminator.
                    float shadowFeather = clamp(lambertian / 0.2, 0.0, 1.0);
                    vec3 halfDir = normalize(viewDir + lightVec);
                    float specAngle = max(dot(halfDir, surfaceNormal), 0.0);
                    float specular = pow(specAngle, specExponent);
                    blinnCol += specularFact * specular * pointLights[i].color * shadowFeather;
                }
            }
        #endif

            float t = textureInfluence;
            float l = lightingInfluence;
            // Blend from white (unlit / untextured) towards the real values.
            vec3 litBase = mix(vec3(1.0), lambertCol + ambientFact * ambientLightColor, l);
            vec3 litTexturedBase = mix(vec3(1.0), col, t) * litBase;
            vec3 finalCol = clamp(litTexturedBase + l * blinnCol, 0.0, 1.0);

            gl_FragColor = vec4(finalCol, alpha);
        }
    `,
};