import { DataOutput, PreviousNodeData } from "components/ReteInterfaces/SharedInterfaces";
import { ClassicPreset } from "rete";
import { LinearRegressionModel } from "./Models/LinearRegressionModel";
import { PackageImporter } from "./PackageImporter";
import {
    findModelType,
    loadCollectedImportStatements,
    mergeCommonAncestorCodePaths,
    mergeCommonAncestorNodePaths,
} from "../../helpers/nodeHelpers";

/**
 * Rete node that emits Python code scoring a trained model on a test dataframe
 * via `model.score(X_test, y_test)` and printing the result.
 *
 * Inputs:  `trained_model` ("Trained Model"), `test_dataset` ("Test Dataframe").
 * Output:  `simple_score` ("Score") — the accumulated code lines plus the chain
 *          of upstream nodes, with this node appended.
 */
export class ScoreModelBasic extends ClassicPreset.Node<
    { trained_model: ClassicPreset.Socket; test_dataset: ClassicPreset.Socket },
    { simple_score: ClassicPreset.Socket },
    {}
> {
    width = 240;
    height = 150;
    public score = "";

    constructor(socket: ClassicPreset.Socket) {
        super("Simple Score Model");
        this.addInput("trained_model", new ClassicPreset.Input(socket, "Trained Model"));
        this.addInput("test_dataset", new ClassicPreset.Input(socket, "Test Dataframe"));
        this.addOutput("simple_score", new ClassicPreset.Output(socket, "Score"));
    }

    getType() {
        return "ScoreModelBasic";
    }

    getDisplayName() {
        return "Score Model Basic";
    }

    /**
     * Maps a model type (as reported by `findModelType`) to the Python variable
     * name the generated code stores the score in. Classifier-style types map to
     * `Accuracy`, regressor-style types to `R_squared`.
     *
     * Typed as a string-keyed record so it can be indexed by an arbitrary model
     * type string under `noImplicitAny`.
     */
    private scoreVariableNames: Readonly<Record<string, string>> = {
        "linear_regression": "R_squared",
        "logistic_regression": "Accuracy",
        "decision_tree_classifier": "Accuracy",
        "random_forest_classifier": "Accuracy",
        "support_vector_classifier": "Accuracy",
        "k_nearest_neighbors": "Accuracy",
        "decision_tree_regressor": "R_squared",
        "random_forest_regressor": "R_squared",
        "support_vector_regressor": "R_squared",
    };

    /**
     * Validates both inputs.
     *
     * Besides the validity flag, it collects what should be forwarded downstream
     * if validation fails. That is one Python comment line per problem, plus
     * whatever upstream code and nodes are available from the inputs that *are*
     * connected.
     *
     * @param trained_model - first value on the `trained_model` input, or `null` if unconnected.
     * @param test_dataset - first value on the `test_dataset` input, or `null` if unconnected.
     * @returns `allValidConnections` is true only when both inputs are present and
     *          each upstream chain contains a port with the expected label.
     */
    #checkConnections(
        trained_model: DataOutput | null,
        test_dataset: DataOutput | null
    ): {
        allValidConnections: boolean;
        codeMessageIfInvalid: string;
        previousCodeIfError: string[];
        previousNodesIfError: PreviousNodeData[];
    } {
        let allValidConnections = true;
        let codeMessageIfInvalid = "";
        let previousCodeIfError: string[] = [];
        let previousNodesIfError: PreviousNodeData[] = [];

        if (!trained_model) {
            allValidConnections = false;
            codeMessageIfInvalid +=
                "# Connect a trained model to the `Trained Model` port on the Simple Score Model block\n";
        } else {
            previousCodeIfError = [...trained_model.code];
            previousNodesIfError = [...trained_model.previousNodes];
            // Some node in the upstream chain must be connected through a port labelled "Trained Model".
            const hasTrainedModel = trained_model.previousNodes.some(
                (node) => node.connectedPort.label === "Trained Model"
            );
            if (!hasTrainedModel) {
                allValidConnections = false;
                codeMessageIfInvalid +=
                    "# Invalid connection at `Trained Model` port on the Simple Score Model block\n";
            }
        }

        if (!test_dataset) {
            allValidConnections = false;
            // The input is labelled "Test Dataframe" (see constructor); the message must match it.
            codeMessageIfInvalid +=
                "# Connect a test dataset to the `Test Dataframe` port on the Simple Score Model block\n";
        } else {
            previousNodesIfError = mergeCommonAncestorNodePaths(
                previousNodesIfError,
                test_dataset.previousNodes
            );
            const importStatement = loadCollectedImportStatements(previousNodesIfError);
            // NOTE(review): code is merged test-first here but trained-model-first in data();
            // confirm whether mergeCommonAncestorCodePaths is order-sensitive.
            previousCodeIfError = mergeCommonAncestorCodePaths(
                test_dataset.code,
                previousCodeIfError,
                importStatement
            );
            const hasTestDataset = test_dataset.previousNodes.some(
                (node) => node.connectedPort.label === "Test Dataframe"
            );
            if (!hasTestDataset) {
                allValidConnections = false;
                codeMessageIfInvalid +=
                    "# Invalid connection at `Test Dataframe` port on the Simple Score Model block\n";
            }
        }

        return {
            allValidConnections,
            codeMessageIfInvalid,
            previousCodeIfError,
            previousNodesIfError,
        };
    }

    /**
     * Chooses the Python variable name for the score from the upstream model type.
     *
     * @returns `"unknown_score"` when no model type can be found, the mapped name
     *          for a known type, or `"Score"` for an unrecognised type.
     */
    #getScoreVariableBasedOnModelType(previousNodes: PreviousNodeData[]): string {
        const modelType = findModelType(previousNodes);
        if (!modelType) {
            return "unknown_score";
        }
        return this.scoreVariableNames[modelType] ?? "Score";
    }

    /** Builds this node's entry for a downstream `previousNodes` chain, tagged with its output port. */
    #asPreviousNode(): PreviousNodeData {
        return { ...this, connectedPort: this.outputs.simple_score! };
    }

    /**
     * Produces the `simple_score` output.
     *
     * If either input is missing or invalid, the output carries the available
     * upstream code followed by explanatory Python comments. Otherwise it merges
     * both upstream code paths and appends the score/print statements.
     */
    data(inputs: { trained_model?: DataOutput[]; test_dataset?: DataOutput[] }): {
        simple_score: DataOutput;
    } {
        const trainedModelInput = inputs.trained_model?.[0] ?? null;
        const testDataset = inputs.test_dataset?.[0] ?? null;

        const {
            allValidConnections,
            codeMessageIfInvalid,
            previousCodeIfError,
            previousNodesIfError,
        } = this.#checkConnections(trainedModelInput, testDataset);

        // allValidConnections already implies both inputs are non-null; the explicit
        // checks just let the compiler narrow them without non-null assertions.
        if (!allValidConnections || !trainedModelInput || !testDataset) {
            return {
                simple_score: {
                    code: [...previousCodeIfError, codeMessageIfInvalid],
                    previousNodes: [...previousNodesIfError, this.#asPreviousNode()],
                },
            };
        }

        const previousNodes = mergeCommonAncestorNodePaths(
            trainedModelInput.previousNodes,
            testDataset.previousNodes
        );
        const importStatement = loadCollectedImportStatements(previousNodes);
        const previousCode = mergeCommonAncestorCodePaths(
            trainedModelInput.code,
            testDataset.code,
            importStatement
        );

        const scoreVariable = this.#getScoreVariableBasedOnModelType(
            trainedModelInput.previousNodes
        );
        const currentCode =
            `${scoreVariable} = model.score(X_test, y_test)\n` +
            `print("${scoreVariable} is", ${scoreVariable})\n`;

        const simple_score: DataOutput = {
            code: [...previousCode, currentCode],
            previousNodes: [...previousNodes, this.#asPreviousNode()],
        };
        return { simple_score };
    }
}
