import * as tf from '@tensorflow/tfjs'
import {preProcessRawUint8InputImage, postProcessImageTensor} from './ModelUtils'

// Divisor used to derive the per-iteration PGD step size from the perturbation
// radius: stepSize = eps / NUM_ITERATIONS * 2 (see the PGDAttack constructor).
const NUM_ITERATIONS = 8;

/**
 * Projected Gradient Descent (PGD) adversarial attack on an image classifier.
 *
 * Each doStep() takes one gradient-ascent step on the cross-entropy loss of the
 * class the model originally predicted for the source image, then projects the
 * result back into the eps-ball around the source image (L-inf or L2),
 * intersected with the [-1, 1] input range the image is preprocessed into.
 *
 * Per-step intermediate tensors are freed with tf.tidy; the tensors an instance
 * keeps (source image, current attacked image) can be released with dispose().
 */
export class PGDAttack {
    /**
     * @param {tf.LayersModel} sourceModel - layers model with symbolic `inputs`/`outputs`;
     *     its logits are taken from the input of its first output's layer (see getLogitOutput).
     * @param {{data: ArrayLike<number>, width: number, height: number}} sourceImageData -
     *     ImageData-like raw image; NOTE(review): pixel layout is whatever
     *     preProcessRawUint8InputImage expects (values mapped from [0, 255] to [-1, 1]).
     * @param {number} eps - perturbation radius, in the [-1, 1] input scale.
     * @param {string|number} [lp="inf"] - norm of the perturbation ball: "inf" or "2".
     * @throws {Error} if `lp` is not a supported norm.
     */
    constructor(sourceModel, sourceImageData, eps, lp = "inf") {
        this.sourceModel = sourceModel;
        const {data, width, height} = sourceImageData;
        // Output frames are rendered back at the original image size.
        this.targetWidth = width;
        this.targetHeight = height;
        const [, modelHeight, modelWidth] = sourceModel.inputs[0].shape;
        this.sourceImageTensor = this.imagePreProcess(data, width, height, modelWidth, modelHeight);
        // Sub-model exposing the pre-softmax logits, which the cross-entropy loss consumes.
        this.logitModel = tf.model({
            inputs: this.sourceModel.inputs, outputs: [this.getLogitOutput()]});
        // Must be set before getAdv(): its class is the label the attack moves away from.
        this.prediction = this.getPrediction();
        console.log(this.prediction);
        this.lp = lp;
        // The attack starts from an unperturbed copy of the source image.
        this.attackedTensor = tf.clone(this.sourceImageTensor);
        this.eps = eps;
        this.stepSize = eps / NUM_ITERATIONS * 2;
        this.advFunc = this.getAdv(this.sourceImageTensor, this.lp);
    }

    /**
     * Top-1 prediction for an image.
     *
     * @param {tf.Tensor} [imageTensor] - preprocessed image; defaults to the source image.
     * @param {tf.Tensor} [logits] - precomputed logits; when given, the model is not run.
     *     Caller-provided tensors are never disposed here.
     * @returns {{prob: number, cls: number, origClsProb: number}} highest softmax
     *     probability and its index, plus the probability of the initially predicted
     *     class (equal to `prob` while the initial prediction is being computed).
     */
    getPrediction(imageTensor, logits) {
        const input = imageTensor || this.sourceImageTensor;
        // tidy frees the model output and softmax intermediates; only the
        // returned flat probability tensor survives and is disposed below.
        const probsTensor = tf.tidy(() => {
            const scores = logits || this.logitModel.predict(input);
            return scores.softmax().flatten();
        });
        const probs = probsTensor.dataSync();
        probsTensor.dispose();

        // Argmax; strict '>' keeps the first index among ties.
        let cls = 0;
        for (let i = 1; i < probs.length; i++) {
            if (probs[i] > probs[cls]) {
                cls = i;
            }
        }
        const prob = probs[cls];
        const origClsProb = this.prediction ? probs[this.prediction.cls] : prob;
        return {prob, cls, origClsProb};
    }

    /** @returns {Array<number|null>} the shape of the model's first input. */
    getInputShape() {
        return this.sourceModel.inputs[0].shape;
    }

    /**
     * Performs one PGD iteration and renders the new attacked image.
     *
     * @param {number} [eps] - new perturbation radius. When given, the step size is
     *     recomputed from it the same way the constructor does. When omitted, the
     *     current radius is kept.
     * @returns {{data: Uint8Array, w: number, h: number, startPred: object, currentPred: object}}
     */
    doStep(eps) {
        if (eps !== undefined) {
            this.eps = eps;
            this.stepSize = eps / NUM_ITERATIONS * 2;
        }
        const {attacked, grad} = this.advFunc(this.attackedTensor);
        // The gradient is not exposed to callers, and the previous image is replaced.
        grad.dispose();
        this.attackedTensor.dispose();
        this.attackedTensor = attacked;
        return this.renderTensor(attacked);
    }

    /**
     * Renders the current attacked image, optionally blended with the source image.
     *
     * @param {number} [alpha=1.0] - weight of the attacked image; the source image
     *     gets weight (1 - alpha).
     * @returns {{data: Uint8Array, w: number, h: number, startPred: object, currentPred: object}}
     */
    getCurrent(alpha = 1.0) {
        if (alpha === 1.0) {
            return this.renderTensor(this.attackedTensor);
        }
        const blended = tf.tidy(() =>
            this.attackedTensor.mul(alpha).add(this.sourceImageTensor.mul(1.0 - alpha)));
        try {
            return this.renderTensor(blended);
        } finally {
            blended.dispose();
        }
    }

    /**
     * Predicts on `imageTensor` and converts it to raw pixels at the original size.
     * `imageTensor` itself is never disposed.
     */
    renderTensor(imageTensor) {
        const currentPred = this.getPrediction(imageTensor);
        const postProcessed = tf.tidy(() => postProcessImageTensor(imageTensor,
            {width: this.targetWidth, height: this.targetHeight}));
        // Assumes a 4-D [batch, height, width, channels] result, as the original indexing did.
        const [, h, w] = postProcessed.shape;
        const data = new Uint8Array(postProcessed.dataSync());
        // Guard against the post-processor handing back its input unchanged.
        if (postProcessed !== imageTensor) {
            postProcessed.dispose();
        }
        return {data, w, h, startPred: this.prediction, currentPred};
    }

    /**
     * Builds the single-step PGD update for the given norm.
     *
     * @param {tf.Tensor} x - clean (preprocessed) source image; the centre of the eps-ball.
     * @param {string|number} lp - "inf" (signed-gradient step, box projection) or
     *     "2" (L2-normalised step, L2-ball projection).
     * @returns {function(tf.Tensor): {attacked: tf.Tensor, grad: tf.Tensor}} a step
     *     function. Its intermediates are tidied; the caller owns both returned tensors.
     *     `this.eps` and `this.stepSize` are read on every call.
     * @throws {Error} for an unsupported norm.
     */
    getAdv(x, lp) {
        // Label the attack pushes the prediction away from.
        const y = this.prediction.cls;
        const normKind = String(lp);
        // Lower bound for L2 norms so all-zero vectors do not produce 0/0 = NaN.
        const minNorm = 1e-12;

        const stopGradient = tf.customGrad((x_, save) => {
            save([x_]);
            return {
              value: x_.add(0), // if we'd just return x, the gradient override would not be used
              gradFunc: (dy, saved) => [tf.zerosLike(saved[0])]
            };
          });

        const loss = (x_) => this.getLossFromLogits(this.getLogits(x_), y);

        let normalizeGrad;
        let project;
        if (normKind === "2") {
            const safeL2Norm = (v) => tf.maximum(tf.norm(v, 2), minNorm);
            normalizeGrad = (g) => g.div(safeL2Norm(g));
            // Clip to the valid input range, then shrink the offset from x onto the
            // L2 ball of radius eps. Scaling toward x stays inside the range since x does.
            project = (v) => {
                const diff = tf.clipByValue(v, -1, 1).sub(x);
                const diffNorm = safeL2Norm(diff);
                return x.add(diff.div(diffNorm).mul(tf.minimum(this.eps, diffNorm)));
            };
        } else if (normKind === "inf") {
            normalizeGrad = (g) => tf.sign(g);
            // Clip to the valid input range, then to the per-pixel box [x - eps, x + eps].
            project = (v) => {
                const clipped = tf.clipByValue(v, -1, 1);
                return tf.minimum(tf.maximum(clipped, x.add(-this.eps)), x.add(this.eps));
            };
        } else {
            throw new Error(`PGDAttack: unsupported norm "${lp}"; expected "inf" or "2"`);
        }

        return (adv_) => tf.tidy(() => {
            const grad = normalizeGrad(tf.grad(loss)(adv_));
            const attacked = stopGradient(project(adv_.add(grad.mul(this.stepSize))));
            // tidy keeps the tensors in the returned object alive.
            return {attacked, grad};
        });
    }

    /**
     * Softmax cross-entropy of `logits` against the one-hot label `y`.
     *
     * @param {tf.Tensor} logits - flattened before use.
     * @param {number} y - target class index.
     * @returns {tf.Tensor} scalar loss.
     */
    getLossFromLogits(logits, y) {
        const flatLogits = logits.flatten();
        let oneHot = new Int32Array(flatLogits.shape[0]);
        oneHot[y] = 1;
        oneHot = tf.tensor(oneHot);
        const loss = tf.losses.softmaxCrossEntropy(oneHot, flatLogits);
        return loss;
    }

    /** Runs the logit sub-model on `x`. */
    getLogits(x) {
        return this.logitModel.predict(x);
    }

    /**
     * Symbolic tensor feeding the layer that produces the model's first output.
     * NOTE(review): assumes that final layer is the softmax, so its input is the logits.
     */
    getLogitOutput() {
        const logits = this.sourceModel.outputs[0].inputs[0];
        return logits;
    }

    /** Resizes raw [0, 255] pixels to the model input size and maps them to [-1, 1]. */
    imagePreProcess(flatData, inWidth, inHeight, modelInWidth, modelInHeight) {
        return preProcessRawUint8InputImage(
            flatData, {w:inWidth, h:inHeight}, {w:modelInWidth, h:modelInHeight}, {min:0, max:255}, {min:-1, max:1});
    }

    /**
     * Releases the tensors this instance owns. The models are left untouched,
     * because the logit sub-model shares its layers with `sourceModel`.
     */
    dispose() {
        this.sourceImageTensor.dispose();
        this.attackedTensor.dispose();
    }
}