import {Model} from './Model'
import * as tf from '@tensorflow/tfjs'
import {getArchFromDirName, getModelPathFromDirName} from './ModelMetaInfo'

/**
 * Debug helper: zero the leading values of a layer's kernel and write them back.
 *
 * Looks up `layerName` on `model.sourceModel`, copies the layer's first weight
 * (treated as the kernel), sets its first `count` values to 0, and stores the
 * result with `setWeights`. Every other weight of the layer is passed back unchanged.
 *
 * @param {Object} model - object exposing the TF.js model as `sourceModel`
 *   (assumed to be a tf.LayersModel — TODO confirm against Model).
 * @param {string} layerName - name of the layer whose kernel is modified.
 * @param {number} [count=4000] - number of leading kernel values to zero; values
 *   beyond the kernel's size are ignored instead of raising a RangeError.
 */
function tempModifyWeights (model, layerName, count = 4000) {
    const layer = model.sourceModel.getLayer(layerName);
    const [kernel, ...otherWeights] = layer.getWeights();

    // Work on a copy: dataSync() can return the backend's own buffer (e.g. the
    // CPU backend), and we must not mutate a live weight tensor behind tfjs' back.
    const values = kernel.dataSync().slice();
    // TypedArray#fill clamps `end` to the array length, so small kernels are safe.
    values.fill(0, 0, count);

    const newKernel = tf.tensor(values, kernel.shape);
    layer.setWeights([newKernel, ...otherWeights]);
    // setWeights assigns into the layer's own variables; this temporary is no longer needed.
    newKernel.dispose();

    console.log(`modified weights of layer ${layer.name}`);
}

/**
 * Holds a fixed number of Model slots and keeps each slot in sync with a
 * model directory name.
 *
 * Slot `i` is rebuilt (via `setModel`) whenever its current model's
 * `modelName` differs from `modelDirs[i]`. Replaced models are disposed.
 */
export class ModelManager {
    /**
     * @param {number} capacity - number of model slots; all start out as null.
     */
    constructor(capacity) {
        this.models = Array.from({ length: capacity }, () => null);
        this.capacity = capacity;
    }

    /**
     * Record the directories for slots 0 and 1 and rebuild any slot whose
     * current model does not match. New models are created but not initialized.
     *
     * Always addresses slots 0 and 1, so it throws (from setModel) when
     * `capacity` is smaller than 2 and a rebuild is needed.
     *
     * @param {{model1: string, model2: string}} modelDirs - directory names for slots 0 and 1.
     */
    setModelDirs(modelDirs) {
        this.modelDirs = [modelDirs.model1, modelDirs.model2];
        this.modelDirs.forEach((dirName, i) => {
            const current = this.models[i];
            if (!current || current.modelName !== dirName) {
                // setModel disposes the previous model before replacing it.
                this.setModel(i, () => {}, false);
            }
        });
    }

    /**
     * Return the model in slot `index`, rebuilding it if it is missing or was
     * built for a different directory than `modelDirs[index]`.
     *
     * @param {number} index - slot to fetch.
     * @param {boolean} [forceInit=false] - when true, make sure the returned
     *   model has had `init` called (either on rebuild or if not yet loaded).
     * @returns {Model} the model now held in the slot.
     * @throws {RangeError} if a rebuild is needed and `index` is not a valid slot.
     */
    getModel(index, forceInit = false) {
        const current = this.models[index];
        if (!current || current.modelName !== this.modelDirs[index]) {
            this.setModel(index, () => {}, forceInit);
        } else if (forceInit && !current.isLoaded()) {
            current.init(() => {});
        }
        return this.models[index];
    }

    /**
     * Replace slot `index` with a new Model built from `this.modelDirs[index]`.
     * Any model already in the slot is disposed first.
     *
     * @param {number} index - slot to rebuild; must be an integer in [0, capacity).
     * @param {Function} [callback] - forwarded to the new model's `init` when `init` is true.
     * @param {boolean} [init=false] - whether to initialize the new model immediately.
     * @throws {RangeError} if `index` is not a valid slot.
     */
    setModel(index, callback = () => {}, init = false) {
        if (!Number.isInteger(index) || index < 0 || index >= this.capacity) {
            throw new RangeError("invalid index!");
        }
        const dirName = this.modelDirs[index];
        if (this.models[index]) {
            // Release the model being replaced so no caller path leaks it.
            this.models[index].dispose();
        }
        const archType = getArchFromDirName(dirName);
        const modelPath = getModelPathFromDirName(dirName);
        const model = new Model(archType, modelPath);
        this.models[index] = model;
        if (init) {
            // init belongs to the Model instance, not to the directory-name string.
            model.init(callback);
        }
    }
}