import React from 'react';
import TextField from '@material-ui/core/TextField';
import { CustomSlider } from './ParamControls/CustomSlider';
import {
    FormControl, FormControlLabel, MenuItem,
    Select, InputLabel, IconButton, Switch
} from '@material-ui/core'
import { outputTypes } from '../Model/Outputs'
import { getCenterNeuronCoords } from '../Model/ModelUtils'
import { TuningCurveRenderer } from './TuningCurveComponents/TuningCurveRenderer'
import { TuningCurveLinePlot } from './TuningCurveComponents/TuningCurveLinePlot'

import { LayerView } from './LayerComponents/LayerView'
import { getPoolWidthFromDefParam } from '../Model/OutputUtils'
import { AddCircleOutline, RemoveCircleOutline } from '@material-ui/icons';
import {getNeuronThumbURL, archId, getDefaultNewActivationParams, modelDirNames} from '../Model/ModelMetaInfo'

import * as config from '../config'

/**
 * Initial activation parameter sets, one per activation column.
 * Each entry names the model index, layer, neuron indices and the display options
 * (`showAllNeurons`, `normalize`, `poolWidth`) used by the column controls.
 */
export const defaultActivationParams = [
    {
        modelIndex: 0,
        layer: 'mixed4a_pre_relu',
        neurons: [308, 222, 234, 325],
        showAllNeurons: true,
        normalize: 0,
        poolWidth: 1,
    },
    {
        modelIndex: 0,
        layer: 'mixed4b_pre_relu',
        neurons: [443, 429],
        showAllNeurons: false,
        normalize: 0,
        poolWidth: 1,
    },
];

/**
 * Returns a fresh initial activation store.
 * `act` maps a column index to `{ data, shape }`; only column 0 is populated here,
 * with four Float32 values and a 1x1 shape.
 * @returns {{act: Object<number, {data: Float32Array, shape: {w: number, h: number}}>}}
 */
export function getDefaultActivations() {
    const firstColumn = {
        data: new Float32Array([0.5, -0.5, 0.2, 0.3]),
        shape: { w: 1, h: 1 },
    };
    return { act: { 0: firstColumn } };
}

// Placeholder activation entry for a newly added column: a single zero value in a 1x1 grid.
const defaultNewActivations = () => ({
    data: new Float32Array([0]),
    shape: { w: 1, h: 1 },
});

/**
 * Appends a new activation column.
 *
 * Adds a placeholder entry to `activations.act` (keyed by the current entry count), pushes
 * default params for `archType` onto `activationParams`, and appends model index 0.
 * NOTE(review): `activations` and `activationParams` are mutated in place and passed back
 * to their setters unchanged; only `modelInds` is replaced by a new array.
 */
function addActivationParams(archType, activationParams, setActivationParams,
    activations, setActivations, modelInds, setModelInds) {
    const nextActIndex = Object.keys(activations.act).length;
    activations.act[nextActIndex] = defaultNewActivations();

    // New columns start on model 0.
    const updatedModelInds = [...modelInds, 0];

    activationParams.push(getDefaultNewActivationParams(archType));

    setActivations(activations);
    setActivationParams(activationParams);
    setModelInds(updatedModelInds);
}

/**
 * Key under which the output-def generators of activation column `index` are registered.
 * @param {number} index column index
 * @returns {string} e.g. "ActivationView0"
 */
export function getActivationViewKey(index) {
    const prefix = "ActivationView";
    return prefix + index;
}

/**
 * Removes the last activation column.
 *
 * Deletes the last `activations.act` entry (by entry count), drops the last model index,
 * deregisters that column's output-def generators from the model it used, and pops the
 * last parameter set.
 * NOTE(review): `activations` and `activationParams` are mutated in place and passed back
 * to their setters unchanged; only `modelInds` is replaced by a new array.
 */
function removeActivationParams(modelManager, activationParams, setActivationParams,
    activations, setActivations, modelInds, setModelInds) {
    const lastActIndex = Object.keys(activations.act).length - 1;
    delete activations.act[lastActIndex];

    const remainingModelInds = [...modelInds];
    const removedModelInd = remainingModelInds.pop();
    // The removed column was the last one, so its view index equals the new length.
    modelManager
        .getModel(removedModelInd)
        .outputManager
        .deregisterOutputDefGenerators(getActivationViewKey(remainingModelInds.length));

    activationParams.pop();

    setActivations(activations);
    setActivationParams(activationParams);
    setModelInds(remainingModelInds);
}

export function getPlusMinusView(modelManager, activationParams, setActivationParams,
    activations, setActivations, modelInds, setModelInds) {
    const displayRemoveButton = Object.keys(activationParams).length > 1;
    const model = modelManager.getModel(modelInds[modelInds.length-1]);
    return (
        <div className="column" style={{ height: "100%" }}>
            <IconButton size="large"
                onClick={() => addActivationParams(model.archType,
                    activationParams, setActivationParams,
                    activations, setActivations, modelInds, setModelInds)}
            >
                <AddCircleOutline fontSize="inherit" />
            </IconButton>
            {displayRemoveButton ? <IconButton size="large"
                onClick={() => removeActivationParams(modelManager, activationParams, setActivationParams,
                    activations, setActivations, modelInds, setModelInds)}
            >
                <RemoveCircleOutline fontSize="inherit" />
            </IconButton> : ""}
        </div>);
}

/**
 * Formats `params.neurons` as a space-separated string (e.g. "308 222").
 * @param {{neurons: number[]}} params activation parameter object
 * @returns {string}
 */
function getNeuronString(params) {
    return params.neurons.reduce(
        (text, neuron, position) => (position ? text + " " + neuron : text + neuron),
        "");
}

function getFeatureVisRow(model, params, neuronWidth) {
    return params.neurons.map((n, i) => {
        const src = getNeuronThumbURL(model, params.layer, n);
        return (
            <img width={neuronWidth} height={neuronWidth} src={src}></img>
        )
    })
}

function getModelSelect(modelNames, modelInd, onChange) {
    return (
        <FormControl >
            <InputLabel id="demo-simple-select-label">Model</InputLabel>
            <Select
                labelId="demo-simple-select-label"
                id="demo-simple-select"
                value={modelInd}
                onChange={onChange}
            >
                {
                    modelNames.map((l, i) => {
                        return <MenuItem value={i}>{l}</MenuItem>
                    })
                }
            </Select>
        </FormControl>)
}

function getLayerSelect(index, layers, layer, onChange) {
    return (
        <FormControl >
            <InputLabel id="demo-simple-select-label">Layer</InputLabel>
            <Select
                labelId="demo-simple-select-label"
                id="demo-simple-select"
                value={layer}
                onChange={onChange}
            >
                {
                    layers.map((l) => {
                        return <MenuItem value={l}>{l}</MenuItem>
                    })
                }
            </Select>
        </FormControl>)
}

export function getLayerActivationVisColumn(columnProps) {
    const { props, activationParams, setActivationParams,
        layerActivations, neuronWidth, updateReceptiveField,
    modelInds, setModelInds} = columnProps;

    const { modelManager, showSortedNeurons, showTuningCurves, tuningCurveResults } = props;

    return activationParams.map((params, index) => {
        const model = modelManager.getModel(modelInds[index], true);
        const modelLayers = model.getLayerNames();
        const width = neuronWidth * params.neurons.length;
        let neuronString = getNeuronString(params);
        const normString = activationParams[index].normalize == 0 ? 'layer stdev.' : Math.floor(Math.pow(10, activationParams[index].normalize));

        const getConditionalStyle = (show) => {
            return {
                transition: "height 1s",
                height: show ? neuronWidth : 0,
                overflow: "hidden"
            }
        }

        let modelNames = modelManager.modelDirs;

        return (
            <div className="column">
                <div className='row'>
                    {getFeatureVisRow(model, params, neuronWidth)}
                </div>
                {<div className='row'
                    style={getConditionalStyle(showSortedNeurons)}>
                    <LayerView
                        model={model}
                        activationParams={activationParams}
                        layerActivations={layerActivations}
                        width={width}
                        height={neuronWidth}
                        index={index} />
                </div>}
                <div className='row'
                    style={getConditionalStyle(showTuningCurves &&
                        Object.entries(tuningCurveResults).length !== 0)}>
                    <TuningCurveRenderer
                        tuningCurveResults={tuningCurveResults}
                        activationParams={activationParams}
                        width={width}
                        height={neuronWidth}
                        index={index} />
                </div>
                <div className='row'
                    style={getConditionalStyle(showTuningCurves &&
                        Object.entries(tuningCurveResults).length !== 0)}>
                    <TuningCurveLinePlot
                        tuningCurveResults={tuningCurveResults}
                        activationParams={activationParams}
                        width={width}
                        height={neuronWidth}
                        index={index} />
                </div>
                <div className={'column'} style={{ padding: "10px", width: width }}>
                    {getModelSelect(modelNames,
                     modelInds[index], (evt) => {
                         let newModelInds = [...modelInds]
                         newModelInds[index] = evt.target.value;
                        setModelInds(newModelInds);
                    })}
                    {getLayerSelect(index, modelLayers, params.layer, (evt) => {
                        let newParams = [...activationParams];
                        newParams[index].layer = evt.target.value;
                        setActivationParams(newParams);
                        updateReceptiveField();
                    })}
                    <TextField id="standard-required" label="Neurons" defaultValue={neuronString}
                        onChange={(evt) => {
                            const tokens = evt.target.value.split(" ");
                            const neurons = tokens.map((t) => parseInt(t));
                            let newParams = [...activationParams];
                            newParams[index].neurons = neurons;
                            setActivationParams(newParams);
                        }} />
                    <FormControlLabel
                        control={
                            <Switch checked={activationParams[index].showAllNeurons}
                                onChange={(evt) => {
                                    let newParams = [...activationParams];
                                    newParams[index].showAllNeurons = evt.target.checked;
                                    setActivationParams(newParams);
                                }} />}
                        label={activationParams[index].showAllNeurons ?
                            "Show all neurons" : "Show top neurons"}
                    />
                    <CustomSlider
                        labelText={"Pool width: " + (activationParams[index].poolWidth === -1 ? "pool all" : activationParams[index].poolWidth)}
                        valueLabelDisplay={"auto"}
                        step={2}
                        min={-1}
                        max={21}
                        value={activationParams[index].poolWidth}
                        onChange={(evt, value) => {
                            let newParams = [...activationParams];
                            newParams[index].poolWidth = value;
                            setActivationParams(newParams);
                        }}
                    />
                    <CustomSlider
                        labelText={"Normalization: div. by " + normString}
                        valueLabelDisplay={"auto"}
                        step={0.02}
                        min={0}
                        max={3}
                        value={activationParams[index].normalize}
                        onChange={(evt, value) => {
                            let newParams = [...activationParams];
                            newParams[index].normalize = value;
                            setActivationParams(newParams);
                        }}
                    />
                </div>
            </div>);
    })
}

/**
 * Generates a function that can be registered with a model's OutputManager to receive the specified model output.
 * `actParams.layer`, `actParams.neurons` and `actParams.poolWidth` are read once, when the generator is created.
 * @param {*} actParams single activation parameter object (not to be confused with activationParams of ActivationView, encompassing several "actParams")
 * @param {*} callback callback that gets called by the inferencer when activation data is available.
 * @param {boolean} [includeMoments=false] copied into the definition's `includeMoments` field
 * @param {boolean} [pool=false] when true, the definition also carries `poolWidth` (from actParams) and `poolType: 'avg'`
 * @param {boolean} [layerOutput=false] when true, `channels` is -1 instead of `actParams.neurons`
 * @returns {function(): object} function returning a new output definition object on every call
 */
export function getOutputDefGenerator(actParams, callback, includeMoments = false, pool = false, layerOutput = false) {
    const { layer, neurons, poolWidth } = actParams;
    const baseDef = {
        type: outputTypes.NEURONS,
        layer,
        channels: layerOutput ? -1 : neurons,
        includeMoments,
        callback,
    };
    if (!pool) {
        return () => ({ ...baseDef });
    }
    return () => ({ ...baseDef, poolWidth, poolType: 'avg' });
}

export function getNeuronIndicatorDivs(neuronWidth, activations, activationParams) {
    let divArray = [];
    let currentLeft = 0;
    activationParams.forEach((params, p) => {
        const shape = activations.act[p].shape;
        for (let i = 0; i < params.neurons.length; i++) {
            const centerNeuronWidth = neuronWidth / shape.w;
            const { x, y } = getCenterNeuronCoords(shape);
            const poolWidth = getPoolWidthFromDefParam(params.poolWidth, shape.h, shape.w);
            const cnMarginX = currentLeft + (x - Math.floor(poolWidth / 2)) * centerNeuronWidth;
            const cnMarginY = (y - Math.floor(poolWidth / 2)) * centerNeuronWidth;
            currentLeft += neuronWidth;

            divArray.push(<div style={{
                position: 'absolute',
                pointerEvents: 'none',
                top: cnMarginY,
                left: cnMarginX,
                width: centerNeuronWidth * poolWidth,
                height: centerNeuronWidth * poolWidth,
                zIndex: 1,
                border: '2px dotted white',
                visibility: config.SHOW_RECEPTIVE_FIELD ? 'visible' : 'hidden'
            }}>
            </div>);
        }
    });
    return divArray;
}

/**
 * Creates the inferencer callback for activation column `i`.
 *
 * When `i` exists in `activations.act`, the callback copies `ret.data`, `ret.w` and `ret.h`
 * into that entry. It sets `normalize` to 1/sqrt(variance) when `ret` has a variance and
 * removes `normalize` otherwise. The setter is always called, with a shallow copy of
 * `activations`. That copy shares `act` (and its entries) with the captured object, so the
 * updates land in place.
 */
const getNeuronInferenceCallback = (activations, setActivations, i) => (ret) => {
    const updated = { ...activations };
    const { act } = updated;
    if (i in act) {
        const entry = act[i];
        entry.data = ret.data;
        entry.shape.w = ret.w;
        entry.shape.h = ret.h;
        if (!('variance' in ret)) {
            delete entry.normalize;
        } else {
            entry.normalize = 1 / Math.sqrt(ret.variance);
        }
    }
    setActivations(updated);
};

/**
 * Creates the inferencer callback for the whole-layer output of activation column `i`.
 *
 * The callback writes `ret.data` and a `normalize` factor into `layerActivations[i]`,
 * creating that entry (and, if `layerActivations` is falsy, a fresh object) when missing.
 * It then passes the same object to the setter.
 * The factor is 1/sqrt(variance) when `ret` carries a variance, else 1/10^normalize.
 * `activationParams` is a single parameter object here; only its `normalize` is read.
 */
const getLayerInferenceCallback = (layerActivations, setLayerActivations, activationParams, i) => (ret) => {
    const target = layerActivations ? layerActivations : {};
    if (!(i in target)) {
        target[i] = {};
    }
    const entry = target[i];
    entry.data = ret.data;
    entry.normalize = ('variance' in ret)
        ? 1 / Math.sqrt(ret.variance)
        : 1 / Math.pow(10, activationParams.normalize);
    setLayerActivations(target);
};

/**
 * Creates the inferencer callback that forwards layer output to `animation.callback`.
 *
 * If `inferenceData` has no variance, it is given `mean = 0` and
 * `variance = (10^normalize)^2` in place before forwarding.
 * `activationParams` is a single parameter object here; `paramN` is forwarded as
 * `paramNumber` and `i` as `paramIndex`.
 */
const getAnimationCallback = (animation, activationParams, paramN, i) => (inferenceData, metaData) => {
    const hasMoments = 'variance' in inferenceData;
    if (!hasMoments) {
        const scale = Math.pow(10, activationParams.normalize);
        inferenceData.mean = 0;
        inferenceData.variance = Math.pow(scale, 2);
    }
    const payload = {
        animation,
        inferenceData,
        activationParams,
        paramNumber: paramN,
        paramIndex: i,
        metaData,
    };
    animation.callback(payload);
};

/**
 * Chooses which output-def generator kinds are enabled for the current props.
 * "neurons" is always on; "layer" is added when `props.showSortedNeurons` is truthy.
 * @param {{showSortedNeurons?: boolean}} props component props
 * @returns {{neurons: true, layer?: true}}
 */
export function getEnabledDefGenerators(props) {
    const enabledGenerators = { neurons: true };
    if (props.showSortedNeurons) {
        enabledGenerators.layer = true;
    }
    return enabledGenerators;
}

/**
 * Builds the output-def generators for activation column `index`.
 *
 * In every case, moments are requested exactly when `normalize == 0`.
 * - "neurons" in `enabledGenerators` adds a generator for the selected channels.
 * - "layer" in `enabledGenerators` adds a pooled whole-layer generator feeding `layerActivations`.
 * - A truthy `animation` adds a pooled whole-layer generator feeding the animation.
 * @returns {function[]} generators in the order neurons, layer, animation
 */
export function getDefGenerators(enabledGenerators, activations, index, setActivations, activationParams, animation,
    layerActivations, setLayerActivations, ) {
    const params = activationParams[index];
    const generators = [];

    if ("neurons" in enabledGenerators) {
        const onNeuronData = getNeuronInferenceCallback(activations, setActivations, index);
        generators.push(getOutputDefGenerator(params, onNeuronData, params.normalize == 0));
    }

    if ("layer" in enabledGenerators) {
        const onLayerData = getLayerInferenceCallback(layerActivations, setLayerActivations, params, index);
        generators.push(getOutputDefGenerator(params, onLayerData, params.normalize == 0, true, true));
    }

    if (animation) {
        const onAnimationData = getAnimationCallback(animation, params, activationParams.length, index);
        generators.push(getOutputDefGenerator(params, onAnimationData, params.normalize == 0, true, true));
    }

    return generators;
}

/**
 * Total canvas width: `neuronWidth` pixels for every neuron across all parameter sets.
 * @param {{neurons: number[]}[]} activationParams parameter sets
 * @param {number} neuronWidth width of one neuron tile in pixels
 * @returns {number}
 */
export function getCanvasWidth(activationParams, neuronWidth) {
    return activationParams.reduce(
        (total, { neurons }) => total + neuronWidth * neurons.length,
        0);
}