import { TestTrainSplitControl } from "../ReteControls/TestTrainSplitControls";
import { combineCodeSnippets } from "../../helpers/pyodideHelpers";
import { ClassicPreset } from "rete";
import { DataOutput, PreviousNodeData } from "../ReteInterfaces/SharedInterfaces";
import {
    arraysEqual,
    delayFunction,
    delayMs,
    isEmpty,
    loadCollectedImportStatements,
} from "../../helpers/nodeHelpers";
import { DatasetImporter } from "./DatasetImporter";
import { PackageImporter } from "./PackageImporter";
import { ProblemContextType } from "../../contexts/ProblemContext";

/**
 * Rete node that turns an upstream dataframe into scikit-learn
 * `train_test_split` code.
 *
 * The user picks feature columns, a target column and a test size through the
 * `TestTrainSplitControl`. `data()` emits the resulting Python snippet on both
 * the train and test outputs, appended to the upstream code.
 */
export class TestTrainSplit extends ClassicPreset.Node<
    { dataframe_input: ClassicPreset.Socket },
    { train_output: ClassicPreset.Socket; test_output: ClassicPreset.Socket },
    { test_train_split: TestTrainSplitControl }
> {
    /**
     * Initial target value handed to the control before a column is chosen.
     * Code generation treats it as "no target selected".
     */
    private static readonly TARGET_PLACEHOLDER = "Choose Variable";
    /** Library name used both to register and to look up the scikit-learn import. */
    private static readonly SKLEARN_LIBRARY = "Scikit-Learn";

    width = 375;
    height = 365;
    dataframeColumns: string[] = [];
    selectedFeatures: string[] = [];
    targetVariable: string = TestTrainSplit.TARGET_PLACEHOLDER;
    testSize: number = 0.2;
    // Initial messages shown until the first data() run; wording matches the runtime errors.
    errorMessages: string[] = ["No dataframe found", "scikit-learn must be imported"];
    private update: (type: "node", node: this) => void;
    private change: () => void;
    private debouncedUpdate: () => void;

    /**
     * @param socket  Socket shared by the dataframe input and both outputs.
     * @param update  Callback invoked with ("node", this) after control state changes.
     * @param change  Callback fired (through `delayFunction`, 1000 ms) after user edits.
     * @param context Problem context; currently unused by this node.
     */
    constructor(
        socket: ClassicPreset.Socket,
        update: (type: "node" | "connection", node: TestTrainSplit) => void,
        change: () => void,
        context: ProblemContextType
    ) {
        super("Test Train Split");
        this.addInput("dataframe_input", new ClassicPreset.Input(socket, "Dataframe"));
        this.addOutput("train_output", new ClassicPreset.Output(socket, "Train Dataframe"));
        this.addOutput("test_output", new ClassicPreset.Output(socket, "Test Dataframe"));
        this.addControl(
            "test_train_split",
            new TestTrainSplitControl(
                this.dataframeColumns,
                this.errorMessages,
                this.selectedFeatures,
                this.targetVariable,
                this.testSize,
                false,
                this.setSelectedFeatures.bind(this),
                this.setTargetVariable.bind(this),
                this.setTestSize.bind(this)
            )
        );
        this.update = update;
        this.change = change;
        // NOTE(review): delayFunction presumably debounces; confirm in nodeHelpers.
        this.debouncedUpdate = delayFunction(this.change.bind(this), 1000);
    }

    getType() {
        return "TestTrainSplit";
    }

    getDisplayName() {
        return "Test-Train-Split";
    }

    /** Stores a copy of the chosen feature columns and syncs the control. */
    setSelectedFeatures(selectedFeatures: string[]) {
        this.selectedFeatures = [...selectedFeatures];
        this.controls.test_train_split._setSelectedFeatures(selectedFeatures);
        this.update("node", this);
        this.debouncedUpdate();
    }

    /** Stores the chosen target column and syncs the control. */
    setTargetVariable(targetVariable: string) {
        this.targetVariable = targetVariable;
        this.controls.test_train_split.setTargetFeature(targetVariable);
        this.update("node", this);
        this.debouncedUpdate();
    }

    /** Stores the test-set fraction passed to `train_test_split` and syncs the control. */
    setTestSize(testSize: number) {
        this.testSize = testSize;
        this.controls.test_train_split._setTestSize(testSize);
        this.update("node", this);
        this.debouncedUpdate();
    }

    #setIsLoading(loadingStatus: boolean) {
        this.controls.test_train_split.setIsLoading(loadingStatus);
        this.update("node", this);
    }

    #clearErrorMessages() {
        this.controls.test_train_split.setErrorMessages([]);
        this.update("node", this);
    }

    /** Columns of the upstream "Import Dataset" node, excluding "id"; [] if none found. */
    #getDatasetColumns(previousNodes: PreviousNodeData[]): string[] {
        const datasetImporterNode: DatasetImporter | undefined = previousNodes.find(
            (node): node is DatasetImporter & PreviousNodeData => {
                return node.label === "Import Dataset";
            }
        );
        return datasetImporterNode?.datasetColumns.filter((column) => column !== "id") || [];
    }

    /** True when the accumulated code contains an assignment of the form `df = pd...`. */
    #dataframeExists(code: string): boolean {
        return /df\s*=\s*pd/.test(code);
    }

    /** Replaces the known columns and resets features, target and test size to defaults. */
    #resetSelectedFeatures(columns: string[]) {
        this.dataframeColumns = columns.filter((column) => column !== "id");
        this.selectedFeatures = [];
        this.targetVariable = "";
        this.testSize = 0.2;
        this.controls.test_train_split._setTestSize(this.testSize);
        this.controls.test_train_split.setDataframeColumns(this.dataframeColumns);
        this.controls.test_train_split._setSelectedFeatures([]);
        this.controls.test_train_split.setTargetFeature("");
        this.update("node", this);
    }

    /** Known columns that are not selected as features, each wrapped in double quotes. */
    #getUnselectedFeatures(): string[] {
        return this.dataframeColumns
            .filter((column) => !this.selectedFeatures.includes(column))
            .map((column) => `"${column}"`);
    }

    /** True when a real target column (not empty, not the placeholder) is selected. */
    #hasTargetVariable(): boolean {
        return (
            !!this.targetVariable && this.targetVariable !== TestTrainSplit.TARGET_PLACEHOLDER
        );
    }

    /**
     * Builds the Python snippet for this node.
     *
     * X drops every unselected column. y is emitted only once a target is chosen.
     * The `train_test_split` call is emitted only when both features and a target exist.
     */
    #generateCode(): string {
        const unselectedFeatures = this.#getUnselectedFeatures();
        const hasTarget = this.#hasTargetVariable();
        // # Features dataframe - remove unwanted columns
        let code = `X = df.drop(columns=[
    ${
        !isEmpty(unselectedFeatures)
            ? unselectedFeatures.join(", ")
            : "# Every column selected, nothing to drop"
    }
])
${hasTarget ? `y = df["${this.targetVariable}"]` : "# Select a target variable to continue"}\n`;

        if (!isEmpty(this.selectedFeatures) && hasTarget) {
            code += `X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=${this.testSize}, random_state=42)\n`;
        }
        return code;
    }

    /** Upstream "Import Libraries" node, if any. */
    #findPackageImporter(previousNodes: PreviousNodeData[]): PackageImporter | undefined {
        return previousNodes.find((node): node is PackageImporter & PreviousNodeData => {
            return node.label === "Import Libraries";
        });
    }

    /** Registers the `train_test_split` import with the upstream ImportManager, if present. */
    #addRequiredImport(previousNodes: PreviousNodeData[]) {
        const importStatement = "from sklearn.model_selection import train_test_split";
        const importManager = this.#findPackageImporter(previousNodes)?.ImportManager;
        if (importManager) {
            importManager.addSpecificImport(TestTrainSplit.SKLEARN_LIBRARY, importStatement);
        }
    }

    /** True when the upstream ImportManager lists scikit-learn among its current imports. */
    #hasBaseImport(previousNodes: PreviousNodeData[]): boolean {
        const importManager = this.#findPackageImporter(previousNodes)?.ImportManager;
        return !!importManager?.currentImports?.some(
            (i) => i.libraryName === TestTrainSplit.SKLEARN_LIBRARY
        );
    }

    /**
     * Appends `currentCode` to the upstream code and tags each output with a snapshot
     * of this node. The snapshot is taken now, so it reflects the latest state.
     */
    #buildOutputs(
        previousCode: string[],
        previousNodes: PreviousNodeData[],
        currentCode: string
    ): { train_output: DataOutput; test_output: DataOutput } {
        const currentNodeTrain = { ...this, connectedPort: this.outputs.train_output! };
        const currentNodeTest = { ...this, connectedPort: this.outputs.test_output! };
        return {
            train_output: {
                code: [...previousCode, currentCode],
                previousNodes: [...previousNodes, currentNodeTrain],
            },
            test_output: {
                code: [...previousCode, currentCode],
                previousNodes: [...previousNodes, currentNodeTest],
            },
        };
    }

    /**
     * Dataflow entry point: validates the upstream state and emits code on both outputs.
     *
     * If no `df = pd...` assignment is found, or scikit-learn is not imported, selections
     * are reset and the error comments are emitted instead of split code. The control's
     * loading flag is always cleared on exit, even if an exception is thrown.
     */
    async data(inputs: {
        dataframe_input: DataOutput[];
    }): Promise<{ train_output: DataOutput; test_output: DataOutput }> {
        this.#setIsLoading(true);
        try {
            // `?.[0]?.` guards both a missing input and an empty connection list.
            const upstream = inputs.dataframe_input?.[0];
            const previousNodes = upstream?.previousNodes || [];
            this.#addRequiredImport(previousNodes);
            const importStatement: string = loadCollectedImportStatements(previousNodes);
            // Copy before overwriting slot 0: the original array belongs to the upstream output.
            const previousCode = [...(upstream?.code || [])];
            if (upstream && importStatement) {
                previousCode[0] = importStatement;
            }
            const totalCode = combineCodeSnippets(previousCode);

            const errorComments: { code: string; msg: string }[] = [];
            if (!this.#dataframeExists(totalCode)) {
                errorComments.push({
                    code: "# Test Train Split: No dataframe found",
                    msg: "No dataframe found",
                });
            }
            if (!this.#hasBaseImport(previousNodes)) {
                errorComments.push({
                    code: "# scikit-learn must be imported",
                    msg: "scikit-learn must be imported",
                });
            }
            if (errorComments.length > 0) {
                this.#resetSelectedFeatures([]);
                this.controls.test_train_split.setErrorMessages(errorComments.map((e) => e.msg));
                const errorCode = errorComments.map((e) => e.code).join("\n");
                return this.#buildOutputs(previousCode, previousNodes, errorCode);
            }

            this.#clearErrorMessages();
            const dataframeColumns = this.#getDatasetColumns(previousNodes);
            if (!arraysEqual(dataframeColumns, this.dataframeColumns)) {
                this.#resetSelectedFeatures(dataframeColumns);
            }
            return this.#buildOutputs(previousCode, previousNodes, this.#generateCode());
        } finally {
            // Also triggers update(), so error messages set above are re-rendered.
            this.#setIsLoading(false);
        }
    }
}
