import { DataOutput, PreviousNodeData } from "components/ReteInterfaces/SharedInterfaces";
import { ClassicPreset } from "rete";
import { LinearRegressionModel } from "./Models/LinearRegressionModel";
import { PackageImporter } from "./PackageImporter";
import { loadCollectedImportStatements } from "../../helpers/nodeHelpers";

/**
 * Rete node that appends `model.fit(X_train, y_train)` to the generated code.
 *
 * It reads an untrained model from `model_input` and a training dataframe from
 * `training_dataset_input`. It concatenates the code produced upstream of both
 * inputs and appends the fit call. When either connection is missing or
 * invalid, it instead emits `#`-prefixed hint comments telling the user what
 * to connect.
 */
export class TrainModel extends ClassicPreset.Node<
    { model_input: ClassicPreset.Socket; training_dataset_input: ClassicPreset.Socket },
    { trained_model_output: ClassicPreset.Socket },
    {}
> {
    width = 240;
    height = 150;

    constructor(socket: ClassicPreset.Socket) {
        super("Train Model");
        this.addInput("model_input", new ClassicPreset.Input(socket, "Untrained Model"));
        this.addInput("training_dataset_input", new ClassicPreset.Input(socket, "Train Dataframe"));
        this.addOutput("trained_model_output", new ClassicPreset.Output(socket, "Trained Model"));
    }

    /** Identifier for this node kind. */
    getType() {
        return "TrainModel";
    }

    /** Title shown for this node. */
    getDisplayName() {
        return "Train Model";
    }

    /**
     * Validates both inputs and gathers the upstream code/nodes to forward
     * when validation fails.
     *
     * - The model input is valid only when its first previous node exists and
     *   has a `model` property.
     * - The training input is valid only when at least one of its previous
     *   nodes has `connectedPort.label === "Train Dataframe"`.
     *
     * @param model_input First connection on the model port, or `null`.
     * @param training_dataset_input First connection on the training port, or `null`.
     * @returns Whether everything is valid. Also returns the hint comments to
     *   emit otherwise, and the upstream code/nodes (training chain first,
     *   then model chain) so they are still passed downstream.
     */
    #checkConnections(
        model_input: DataOutput | null,
        training_dataset_input: DataOutput | null
    ): {
        allValidConnections: boolean;
        codeMessageIfInvalid: string;
        previousCodeIfError: string[];
        previousNodesIfError: PreviousNodeData[];
    } {
        let allValidConnections = true;
        let codeMessageIfInvalid = "";
        let previousCodeIfError: string[] = [];
        let previousNodesIfError: PreviousNodeData[] = [];

        if (!model_input) {
            allValidConnections = false;
            codeMessageIfInvalid += "# Connect a model to Train Model block\n";
        } else {
            previousCodeIfError = [...model_input.code];
            previousNodesIfError = [...model_input.previousNodes];
            // Only the first upstream node is inspected. An empty list has no
            // model at all and must be rejected too; otherwise `model.fit(...)`
            // would be emitted with nothing model-like connected.
            const firstUpstream = model_input.previousNodes[0];
            if (!firstUpstream || !("model" in firstUpstream)) {
                allValidConnections = false;
                codeMessageIfInvalid += "# Connect a model to Train Model block\n";
            }
        }

        if (!training_dataset_input) {
            allValidConnections = false;
            codeMessageIfInvalid += "# Connect a training dataset to Train Model block\n";
        } else {
            // The training chain goes before the model chain, matching the
            // order used when generating code on the valid path.
            previousCodeIfError = [...training_dataset_input.code, ...previousCodeIfError];
            previousNodesIfError = [
                ...training_dataset_input.previousNodes,
                ...previousNodesIfError,
            ];
            const hasTrainDataframe = training_dataset_input.previousNodes.some(
                (node) => node.connectedPort.label === "Train Dataframe"
            );
            if (!hasTrainDataframe) {
                allValidConnections = false;
                codeMessageIfInvalid += "# Connect a training dataset to Train Model block\n";
            }
        }

        return {
            allValidConnections,
            codeMessageIfInvalid,
            previousCodeIfError,
            previousNodesIfError,
        };
    }

    /**
     * Registers the model's required import with the "Import Libraries" node,
     * when both can be found.
     *
     * The model is the first node upstream of `modelInput` that has a `model`
     * property. If it has `requiredImports`, its `libraryName` and
     * `importStatement` are passed to `ImportManager.addSpecificImport` on the
     * first node upstream of `trainingDatasetInput` labelled "Import Libraries".
     * If any piece is missing, this does nothing.
     */
    #addRequiredModelImports(modelInput: DataOutput, trainingDatasetInput: DataOutput) {
        const model = modelInput.previousNodes.find((node) => "model" in node) as
            | LinearRegressionModel
            | undefined;
        const requiredImports = model?.requiredImports;
        if (!requiredImports) {
            return;
        }

        // NOTE(review): only the training-dataset chain is searched for the
        // importer node. An importer that sits only upstream of the model is
        // not used here; confirm this is intended.
        const packageImporterNode: PackageImporter | undefined =
            trainingDatasetInput.previousNodes.find(
                (node): node is PackageImporter & PreviousNodeData => {
                    return node.label === "Import Libraries";
                }
            );
        const importManager = packageImporterNode?.ImportManager;
        if (importManager) {
            const { libraryName, importStatement } = requiredImports;
            importManager.addSpecificImport(libraryName, importStatement);
        }
    }

    /**
     * Shallow copy of this node's own enumerable fields, tagged with its
     * output port. It is appended to `previousNodes` for downstream nodes.
     */
    #asPreviousNode(): PreviousNodeData {
        // The output is always added in the constructor; `!` is needed only
        // because the base class types `outputs` entries as optional.
        return { ...this, connectedPort: this.outputs.trained_model_output! };
    }

    /**
     * Builds the `trained_model_output` payload from the first connection on
     * each input.
     *
     * - Invalid connections: returns the upstream code followed by the hint
     *   comments.
     * - Otherwise:
     *   1. Registers the model's required import.
     *   2. Replaces the first upstream code entry with the collected import
     *      statements, when there are any.
     *   3. Appends `model.fit(X_train, y_train)`.
     *
     * In both cases this node is appended last to `previousNodes`.
     */
    data(inputs: { model_input?: DataOutput[]; training_dataset_input?: DataOutput[] }): {
        trained_model_output: DataOutput;
    } {
        const modelInput = inputs.model_input?.[0] ?? null;
        const trainingDataset = inputs.training_dataset_input?.[0] ?? null;

        const {
            allValidConnections,
            codeMessageIfInvalid,
            previousCodeIfError,
            previousNodesIfError,
        } = this.#checkConnections(modelInput, trainingDataset);

        // `allValidConnections` already implies both inputs are non-null. The
        // explicit checks let the compiler narrow them without `!` assertions.
        if (!allValidConnections || !modelInput || !trainingDataset) {
            return {
                trained_model_output: {
                    code: [...previousCodeIfError, codeMessageIfInvalid],
                    previousNodes: [...previousNodesIfError, this.#asPreviousNode()],
                },
            };
        }

        this.#addRequiredModelImports(modelInput, trainingDataset);

        const previousNodes = [...trainingDataset.previousNodes, ...modelInput.previousNodes];
        const importStatement = loadCollectedImportStatements(previousNodes);
        const previousCode = [...trainingDataset.code, ...modelInput.code];
        if (importStatement) {
            // NOTE(review): slot 0 is assumed to hold the upstream import
            // block; it is overwritten with the aggregated imports. TODO confirm.
            previousCode[0] = importStatement;
        }
        const fitCode = "model.fit(X_train, y_train)\n";

        const trained_model_output: DataOutput = {
            code: [...previousCode, fitCode],
            previousNodes: [...previousNodes, this.#asPreviousNode()],
        };
        return { trained_model_output };
    }
}
