import { LogisticRegressionSettings } from "../../ReteControls/Models/LogisticRegressionSettings";
import { DataOutput, PreviousNodeData } from "../../ReteInterfaces/SharedInterfaces";
import { ClassicPreset } from "rete";

/**
 * Editor node that represents an untrained scikit-learn `LogisticRegression`
 * model.
 *
 * The node has no inputs, a single `output` socket, and one `settings`
 * control through which the hyper-parameters (`solver`, `penalty`, `C`,
 * `max_iter`, `l1_ratio`) are edited. `data()` exposes the Python snippet
 * that instantiates the configured model.
 */
export class LogisticRegressionModel extends ClassicPreset.Node<
    {},
    { output: ClassicPreset.Socket },
    { settings: LogisticRegressionSettings }
> {
    width = 250;
    height = 290;
    model = "logistic_regression";
    solver = "liblinear";
    penalty = "l2";
    C = 1.0;
    max_iter = 100;
    l1_ratio = 0.5;
    requiredImports = {
        libraryName: "Scikit-Learn",
        importStatement: "from sklearn.linear_model import LogisticRegression",
    };
    /** Callback invoked with `("node", this)` after the node's height changes. */
    update: (type: "node", asset: LogisticRegressionModel) => void;
    /** Callback invoked after any hyper-parameter has been modified. */
    change: () => void;

    /**
     * @param socket Socket used for the node's single `output` port.
     * @param update Called with `("node", this)` when the penalty (and thus the node height) changes.
     * @param change Called whenever any setting changes.
     */
    constructor(
        socket: ClassicPreset.Socket,
        update: (type: "node", asset: LogisticRegressionModel) => void,
        change: () => void
    ) {
        super("Logistic Regression Model");
        // Store the callbacks before building the control so that a setter
        // triggered during control construction never sees them undefined.
        this.update = update;
        this.change = change;
        // The setters are arrow-function properties, so they are already
        // bound to this instance and can be handed over directly.
        this.addControl(
            "settings",
            new LogisticRegressionSettings(
                this.solver,
                this.penalty,
                this.C,
                this.max_iter,
                this.l1_ratio,
                this.updateL1Ratio,
                this.updateSolver,
                this.updatePenalty,
                this.updateC,
                this.updateMaxIter
            )
        );
        this.addOutput("output", new ClassicPreset.Output(socket, "Untrained Model"));
    }

    /** @returns The type identifier string of this node. */
    getType() {
        return "LogisticRegressionModel";
    }

    /** @returns The human-readable name of this node. */
    getDisplayName() {
        return "Logistic Regression Model";
    }

    /** Stores a new `l1_ratio`, mirrors it on the settings control and signals a change. */
    updateL1Ratio = (value: number) => {
        this.l1_ratio = value;
        this.controls.settings.initialL1Ratio = value;
        this.change();
    };

    /** Stores a new `solver`, mirrors it on the settings control and signals a change. */
    updateSolver = (value: string) => {
        this.solver = value;
        this.controls.settings.initialSolver = value;
        this.change();
    };

    /**
     * Stores a new `penalty`, mirrors it on the settings control, resizes the
     * node and signals both an update and a change.
     */
    updatePenalty = (value: string) => {
        this.penalty = value;
        this.controls.settings.initialPenalty = value;
        // The node is taller for "elasticnet" — presumably to make room for
        // the extra l1_ratio input in the settings control.
        this.height = value === "elasticnet" ? 335 : 290;
        this.update("node", this);
        this.change();
    };

    /** Stores a new `C`, mirrors it on the settings control and signals a change. */
    updateC = (value: number) => {
        this.C = value;
        this.controls.settings.initialC = value;
        this.change();
    };

    /** Stores a new `max_iter`, mirrors it on the settings control and signals a change. */
    updateMaxIter = (value: number) => {
        this.max_iter = value;
        this.controls.settings.initialMaxIter = value;
        this.change();
    };

    /**
     * Builds the Python statement that instantiates the configured model.
     *
     * `l1_ratio` is emitted only when the penalty is `"elasticnet"`; for every
     * other penalty the generated call contains just `solver`, `penalty`, `C`
     * and `max_iter`.
     *
     * @returns A single newline-terminated line of Python source.
     */
    getCurrentCode(): string {
        const args = [
            `solver='${this.solver}'`,
            `penalty='${this.penalty}'`,
            `C=${this.C}`,
            `max_iter=${this.max_iter}`,
        ];
        if (this.penalty === "elasticnet") {
            // Without this argument the user's l1_ratio setting would be
            // silently dropped from the generated code.
            args.push(`l1_ratio=${this.l1_ratio}`);
        }
        return `model = LogisticRegression(${args.join(", ")})\n`;
    }

    /**
     * Produces this node's output payload.
     *
     * @param inputs Unused; the node declares no inputs.
     * @returns An object whose `output` holds the generated code line and a
     * one-element `previousNodes` list containing a shallow copy of this
     * node's own properties together with its `output` port.
     */
    data(inputs: {}): { output: DataOutput } {
        const previousNodes: PreviousNodeData[] = [
            { ...this, connectedPort: this.outputs.output! },
        ];
        const output: DataOutput = {
            code: [this.getCurrentCode()],
            previousNodes,
        };

        return { output };
    }
}
