import React, { useRef, useMemo, useState, useEffect } from 'react'
import { Math as THREEMath, MeshStandardMaterial, ShaderMaterial, DoubleSide, TextureLoader, UniformsUtils, DataTexture, Matrix4, MeshBasicMaterial} from 'three'
import { CustomBlending, AddEquation, SrcAlphaFactor, LessEqualDepth,
     LuminanceFormat, FloatType, DataTexture3D, ImageLoader, NearestFilter, LinearFilter} from 'three'
import { useFrame, useLoader,  } from 'react-three-fiber'
import {NeuronVisShader} from './Shaders'
import {getNeuronThumbLayout} from './NeuronVisPlaneUtil'
import { time } from '@tensorflow/tfjs'
import {getLayerImageURL} from '../../Model/ModelMetaInfo'

/**
 * Creates a new ShaderMaterial from the NeuronVisShader definition.
 *
 * The uniforms are cloned via UniformsUtils.clone, so each material gets its
 * own uniform objects. Writing a uniform value on one material therefore does
 * not touch the shared NeuronVisShader.uniforms template.
 *
 * @returns {ShaderMaterial} A freshly constructed material.
 */
function getMaterial() {
    const { uniforms, vertexShader, fragmentShader } = NeuronVisShader;
    return new ShaderMaterial({
        uniforms: UniformsUtils.clone(uniforms),
        vertexShader,
        fragmentShader,
    });
}

export function NeuronVisPlane(props) {
    const {model, layer, width, height, layerActivations, normalize,
    activationParams, setHoveredNeuron, index} = props;
    const [frame, setFrame] = useState(0);
    const imPath = getLayerImageURL(model, layer);
    const layerImage = useLoader(ImageLoader,imPath);
    const context = useMemo(() => {
        const canvas = document.createElement('canvas');
        const context = canvas.getContext('2d');
        canvas.width = layerImage.width;
        canvas.height = layerImage.height;
        context.drawImage(layerImage, 0, 0);
        return context;
    }, [layer]);
    const showAll = activationParams[index].showAllNeurons;
    const insts = layerImage.width / 64;
    const instsToDisplay = showAll ? insts : 32;
    const thirdCoords = useMemo(() => {
        const coord = new Float32Array(instsToDisplay);
        for (let i = 0; i < instsToDisplay; i++) {
            coord[i] = (i+0.5)/(instsToDisplay);
        }
        return coord;
      }, [instsToDisplay]);
    const thumbLayout = useMemo(() => {
        return getNeuronThumbLayout(width, height, instsToDisplay);
    }, [instsToDisplay, width, height]);
    const activationBuffer = useMemo(() => {
        const arr =  new Float32Array(instsToDisplay);
        for(let i=0; i<instsToDisplay; i++) {
            arr[i] = i/instsToDisplay;
        }
        return arr;
    }, [instsToDisplay]);
    let imData;
    let texture3d;
    const material = useMemo(() => getMaterial(), []);
    const mesh = useRef();
    const attrib = useRef();
    const activAttrib = useRef();

    if(mesh.current){
        mesh.current.count = instsToDisplay;
    }
    
    useFrame(() => {
        if(!imData) {
            imData = context.getImageData(0, 0, layerImage.width, layerImage.height);
            texture3d = new DataTexture3D(imData.data, 64, insts, 64);
            texture3d.minFilter = LinearFilter;
            texture3d.magFilter = LinearFilter;
        }
        mesh.current.material = material;
        material.uniforms.dataTexture.value = texture3d;
        material.uniforms.normalize.value = normalize;
        const ms = 0;
        for(let i=0; i<instsToDisplay; i++){
            const {x, y, w, h} = thumbLayout[i];
            const transMat = new Matrix4().makeTranslation(x, y, 0);
            const scaleMat = new Matrix4().makeScale(w, h, 1);
            mesh.current.setMatrixAt(i, transMat.multiply(scaleMat));
        }
        mesh.current.instanceMatrix.needsUpdate = true;
        attrib.current.needsUpdate = true;
        if(layerActivations[index].data) {
            const argArray = Array.from(layerActivations[index].data).map((d, i) => [d, i]);
            let sorted;
            if(!showAll) {
                sorted = argArray.sort(([a1], [a2]) => a2-a1);
            } else {
                sorted = argArray;
            }
            const data = sorted.map(([d,]) => d);
            const inds = sorted.map(([,i]) => (i+0.5)/insts);
            activationBuffer.set(data.slice(0, instsToDisplay));
            thirdCoords.set(inds.slice(0, instsToDisplay));
        }
        activAttrib.current.needsUpdate = true;
    });
    return (
        <instancedMesh args={[null, null, 1024]}
        onPointerMove={e => {
            const neuron = Math.round(thirdCoords[e.instanceId]*insts-0.5);
            setHoveredNeuron(
                {neuron: neuron, left:e.unprojectedPoint.x, top: -e.unprojectedPoint.y});
        }} onPointerOut={e => setHoveredNeuron({neuron: undefined, left:0, top: 0})}
            ref={mesh} >
            <planeBufferGeometry attach="geometry" args={[1, 1, 1]} >
            <instancedBufferAttribute ref={attrib} attachObject={['attributes', 'thirdCoord']} args={[thirdCoords, 1]} />
            <instancedBufferAttribute ref={activAttrib} attachObject={['attributes', 'activation']} args={[activationBuffer, 1]} />
            </planeBufferGeometry>
            <meshBasicMaterial attach="material" />
        </instancedMesh>
    );
}