import { DataOutput, PreviousNodeData } from "components/ReteInterfaces/SharedInterfaces";
import { ClassicPreset } from "rete";
import {
    loadCollectedImportStatements,
    mergeCommonAncestorCodePaths,
    mergeCommonAncestorNodePaths,
} from "../../helpers/nodeHelpers";

/**
 * Rete node that emits a line of Python code calling `predict` on a trained model.
 *
 * Inputs:
 *  - `trained_model` (port label "Trained Model"): expected to come from an
 *    upstream node whose connected output port is labelled "Trained Model".
 *  - `test_dataset` (port label "Test Dataframe"): expected to come from an
 *    upstream node whose connected output port is labelled "Test Dataset".
 *
 * Output:
 *  - `predictions`: the merged upstream code plus this node's prediction line,
 *    or — when an input is missing or wired to the wrong upstream port — the
 *    upstream code collected so far followed by Python comments telling the
 *    user how to fix the connections.
 */
export class PredictModel extends ClassicPreset.Node<
    { trained_model: ClassicPreset.Socket; test_dataset: ClassicPreset.Socket },
    { predictions: ClassicPreset.Socket },
    {}
> {
    width = 240;
    height = 150;

    /**
     * @param socket Socket shared by both inputs and the `predictions` output.
     */
    constructor(socket: ClassicPreset.Socket) {
        super("Predict Model");
        this.addInput("trained_model", new ClassicPreset.Input(socket, "Trained Model"));
        this.addInput("test_dataset", new ClassicPreset.Input(socket, "Test Dataframe"));
        this.addOutput("predictions", new ClassicPreset.Output(socket, "Predictions"));
    }

    /** Identifier string for this node type. */
    getType() {
        return "PredictModel";
    }

    /** Human-readable name for this node. */
    getDisplayName() {
        return "Predict Model";
    }

    /**
     * Builds the entry this node appends to the `previousNodes` chain: a shallow
     * copy of this node's own properties with `connectedPort` set to the
     * `predictions` output.
     */
    #asPreviousNode(): PreviousNodeData {
        return { ...this, connectedPort: this.outputs.predictions! };
    }

    /**
     * Validates both inputs and gathers whatever upstream code/nodes are
     * available, so that an error result can still carry the partial pipeline.
     *
     * @param trained_model First value on the `trained_model` input, or null if unconnected.
     * @param test_dataset First value on the `test_dataset` input, or null if unconnected.
     * @returns `allValidConnections` is true only when both inputs are present and
     *   each comes from the expected upstream port. `codeMessageIfInvalid` holds
     *   one Python comment line per problem found. `previousCodeIfError` /
     *   `previousNodesIfError` hold the upstream code and nodes merged from
     *   whichever inputs are connected.
     */
    #checkConnections(
        trained_model: DataOutput | null,
        test_dataset: DataOutput | null
    ): {
        allValidConnections: boolean;
        codeMessageIfInvalid: string;
        previousCodeIfError: string[];
        previousNodesIfError: PreviousNodeData[];
    } {
        // Check that both the trained model and the test dataset are connected,
        // and that each is wired from the expected upstream output port.
        let allValidConnections = true;
        let codeMessageIfInvalid = "";
        let previousCodeIfError: string[] = [];
        let previousNodesIfError: PreviousNodeData[] = [];

        if (!trained_model) {
            allValidConnections = false;
            codeMessageIfInvalid +=
                "# Connect a trained model to the `Trained Model` port on the Predict Model block\n";
        } else {
            previousCodeIfError = [...trained_model.code];
            previousNodesIfError = [...trained_model.previousNodes];
            // The upstream chain must contain a node whose connected port is "Trained Model".
            const hasTrainedModel = trained_model.previousNodes.some(
                (node) => node.connectedPort.label === "Trained Model"
            );
            if (!hasTrainedModel) {
                allValidConnections = false;
                codeMessageIfInvalid +=
                    "# Invalid connection at `Trained Model` port on the Predict Model block\n";
            }
        }

        if (!test_dataset) {
            allValidConnections = false;
            // Refer to the port by its actual label ("Test Dataframe", see constructor).
            codeMessageIfInvalid +=
                "# Connect a test dataset to the `Test Dataframe` port on the Predict Model block\n";
        } else {
            previousNodesIfError = mergeCommonAncestorNodePaths(
                previousNodesIfError,
                test_dataset.previousNodes
            );
            const importStatement: string = loadCollectedImportStatements(previousNodesIfError);
            previousCodeIfError = mergeCommonAncestorCodePaths(
                test_dataset.code,
                previousCodeIfError,
                importStatement
            );
            // "Test Dataset" is the label of the upstream *output* port that should
            // feed this input (distinct from this node's "Test Dataframe" input label).
            const hasTestDataset = test_dataset.previousNodes.some(
                (node) => node.connectedPort.label === "Test Dataset"
            );
            if (!hasTestDataset) {
                allValidConnections = false;
                codeMessageIfInvalid +=
                    "# Invalid connection at `Test Dataframe` port on the Predict Model block\n";
            }
        }

        return {
            allValidConnections,
            codeMessageIfInvalid,
            previousCodeIfError,
            previousNodesIfError,
        };
    }

    /**
     * Produces this node's output from its inputs.
     *
     * Only the first value on each input is used. If the connections are invalid,
     * the output code ends with the validation messages instead of the prediction line.
     *
     * @param inputs Values arriving on the `trained_model` and `test_dataset` inputs.
     * @returns The `predictions` output: accumulated code lines and the chain of
     *   previous nodes, with this node appended last.
     */
    data(inputs: { trained_model?: DataOutput[]; test_dataset?: DataOutput[] }): {
        predictions: DataOutput;
    } {
        const trainedModelInput = inputs.trained_model?.[0] ?? null;
        const testDataset = inputs.test_dataset?.[0] ?? null;

        const {
            allValidConnections,
            codeMessageIfInvalid,
            previousCodeIfError,
            previousNodesIfError,
        } = this.#checkConnections(trainedModelInput, testDataset);

        // `allValidConnections` already implies both inputs are non-null; the explicit
        // null checks just let the compiler narrow them without `!` assertions.
        if (!allValidConnections || !trainedModelInput || !testDataset) {
            return {
                predictions: {
                    code: [...previousCodeIfError, codeMessageIfInvalid],
                    previousNodes: [...previousNodesIfError, this.#asPreviousNode()],
                },
            };
        }

        const previousNodes = mergeCommonAncestorNodePaths(
            trainedModelInput.previousNodes,
            testDataset.previousNodes
        );
        const importStatement: string = loadCollectedImportStatements(previousNodes);
        const previousCode = mergeCommonAncestorCodePaths(
            trainedModelInput.code,
            testDataset.code,
            importStatement
        );
        // NOTE(review): the emitted line references the Python names `model` and
        // `X_test`; this assumes upstream blocks define them — confirm.
        const currentCode = `predictions = model.predict(X_test)\n`;

        return {
            predictions: {
                code: [...previousCode, currentCode],
                previousNodes: [...previousNodes, this.#asPreviousNode()],
            },
        };
    }
}
