import { ClassicPreset } from "rete";
import { DataOutput, PreviousNodeData } from "../ReteInterfaces/SharedInterfaces";
import { VisualizerControls } from "../ReteControls/VisualizerControls";
import { PackageImporter } from "./PackageImporter";
import { arraysEqual, loadCollectedImportStatements } from "../../helpers/nodeHelpers";
import { Dispatch, SetStateAction } from "react";
import { DatasetImporter } from "./DatasetImporter";

/**
 * Rete node that appends matplotlib/seaborn plotting code for the dataframe `df`
 * to the code chain received on its "data" input.
 *
 * The snippet is chosen by `plotType` and parameterised by the selected
 * categorical/numerical columns and histogram bin count. Missing prerequisites
 * (no `df = pd...` assignment upstream, Matplotlib/Seaborn not selected in the
 * "Import Libraries" node) are reported through the attached `VisualizerControls`.
 */
export class Visualizer extends ClassicPreset.Node<
    { data: ClassicPreset.Socket },
    { output: ClassicPreset.Socket },
    { controls: VisualizerControls }
> {
    /** Histogram bin count used initially and restored whenever the plot type changes. */
    static readonly DEFAULT_BINS = 10;

    width = 300;
    height = 250;
    plotType: string = "";
    categoricalColumn: string = "";
    numericalColumns: string[] = [];
    selectedSpecies: string[] = [];
    bins: number = Visualizer.DEFAULT_BINS;
    messages: string[] = ["No dataframe found"];

    /**
     * @param socket Socket used for the "data" input.
     * @param update Callback used to ask the editor to re-render this node.
     * @param setIsVisualizerOpen React state setter; invoked with `true` by the controls to open the visualizer.
     * @param change Callback invoked after a user-driven setting changes.
     */
    constructor(
        socket: ClassicPreset.Socket,
        private update: (type: "node", asset: Visualizer) => void,
        setIsVisualizerOpen: Dispatch<SetStateAction<boolean>>,
        private change: () => void
    ) {
        super("Generate Plot");
        this.addControl(
            "controls",
            new VisualizerControls(
                () => setIsVisualizerOpen(true),
                this.setPlotType.bind(this),
                this.setCategoricalColumn.bind(this),
                this.setNumericalColumns.bind(this),
                this.setBins.bind(this),
                this.adjustHeight.bind(this),
                this.messages
            )
        );
        this.addInput("data", new ClassicPreset.Input(socket, "Data"));
    }

    getType() {
        return "Visualizer";
    }
    getDisplayName() {
        return "Visualizer";
    }

    /**
     * Switches the plot type and resets every plot-specific selection (columns
     * and bin count) on both the node and its controls, then signals a change once.
     */
    setPlotType(type: string) {
        const display = this.controls.controls;
        this.plotType = type;
        display._setPlotType(type);
        this.categoricalColumn = "";
        display._setCategoricalColumn("");
        this.numericalColumns = [];
        display._setNumericalColumns([]);
        // Reset via the internal setter (like the fields above) instead of the
        // public control callback, which would route back to this.setBins and
        // fire an extra change().
        this.bins = Visualizer.DEFAULT_BINS;
        display._setBins(Visualizer.DEFAULT_BINS);
        this.change();
    }

    /** Stores the categorical column on the node and its controls, then signals a change. */
    setCategoricalColumn(column: string) {
        this.categoricalColumn = column;
        this.controls.controls._setCategoricalColumn(column);
        this.change();
    }

    /** Stores the numerical columns on the node and its controls, then signals a change. */
    setNumericalColumns(columns: string[]) {
        this.numericalColumns = columns;
        this.controls.controls._setNumericalColumns(columns);
        this.change();
    }

    /** Stores the histogram bin count on the node and its controls, then signals a change. */
    setBins(bins: number) {
        this.bins = bins;
        this.controls.controls._setBins(bins);
        this.change();
    }

    /** Sets the node height and requests a re-render. */
    adjustHeight(value: number) {
        this.height = value;
        this.update("node", this);
    }

    /**
     * Pushes the available dataset columns to the controls. When they differ from
     * the current ones, previous column selections are cleared on both the node
     * and the controls so the generated code and UI stay in sync.
     */
    #setDfColumns(columns: string[]) {
        const display = this.controls.controls;
        if (arraysEqual(columns, display.df_columns)) {
            return;
        }
        display.setDfColumns(columns);
        this.numericalColumns = [];
        this.categoricalColumn = "";
        display._setNumericalColumns([]);
        display._setCategoricalColumn("");
        this.update("node", this);
    }

    /**
     * Registers Matplotlib and Seaborn as base imports on the upstream
     * "Import Libraries" node, if one is present.
     * NOTE(review): not called anywhere in this class — confirm whether it should be wired into `data()`.
     */
    #addRequiredImport(previousNodes: PreviousNodeData[]) {
        const libraryNames = ["Matplotlib", "Seaborn"];
        const packageImporterNode: PackageImporter | undefined = previousNodes.find(
            (node): node is PackageImporter & PreviousNodeData => {
                return node.label === "Import Libraries";
            }
        );
        const ImportManager = packageImporterNode?.ImportManager;
        if (ImportManager) {
            for (const libraryName of libraryNames) {
                ImportManager.addBaseImport(libraryName);
            }
        }
    }

    /** Escapes a value so it can be embedded inside a single-quoted Python string literal. */
    #escapePythonString(value: string): string {
        return value
            .replace(/\\/g, "\\\\")
            .replace(/'/g, "\\'")
            .replace(/\r/g, "\\r")
            .replace(/\n/g, "\\n");
    }

    /**
     * Builds the Python plotting snippet for the current plot type.
     * Unselected columns are rendered as human-readable placeholders; an unknown
     * or empty plot type yields "".
     */
    #getCurrentCode(): string {
        const esc = (value: string) => this.#escapePythonString(value);
        const categorical = esc(this.categoricalColumn || "Select categorical column");

        switch (this.plotType) {
            case "heatmap":
                return ["corr = df.corr()", "sns.heatmap(corr, annot=True)", "plt.show()"].join("\n");

            case "bar_chart": {
                if (this.numericalColumns.length === 0) {
                    // No numerical column: plot category frequencies.
                    return [
                        `df['${categorical}'].value_counts().plot(kind='bar')`,
                        `plt.xlabel('${categorical}')`,
                        "plt.ylabel('Counts')",
                        `plt.title('Frequency of Each ${categorical}')`,
                        "plt.xticks(rotation=45)",
                        "plt.subplots_adjust(bottom=0.25)",
                        "plt.show()",
                    ].join("\n");
                }
                // With a numerical column: plot its mean per category.
                const numerical = esc(this.numericalColumns[0] ?? "");
                return [
                    `df.groupby('${categorical}')['${numerical}'].mean().plot(kind='bar')`,
                    `plt.xlabel('${categorical}')`,
                    `plt.ylabel('Average of ${numerical}')`,
                    `plt.title('Average ${numerical} Grouped by ${categorical}')`,
                    "plt.xticks(rotation=45)",
                    "plt.subplots_adjust(bottom=0.25)",
                    "plt.show()",
                ].join("\n");
            }

            case "histogram": {
                const column = esc(this.numericalColumns[0] || "Select column");
                return [
                    `df['${column}'].plot(kind='hist', bins=${this.bins})`,
                    `plt.xlabel('${column}')`,
                    "plt.ylabel('Frequency')",
                    `plt.title('Histogram of ${column} (${this.bins} bins)')`,
                    "plt.show()",
                ].join("\n");
            }

            case "scatter": {
                const x = esc(this.numericalColumns[0] || "Select x column");
                const y = esc(this.numericalColumns[1] || "Select y column");
                // Colour points by the categorical column only when one is selected.
                const hue = this.categoricalColumn ? `, hue='${categorical}'` : "";
                return [
                    `sns.scatterplot(data=df, x='${x}', y='${y}'${hue})`,
                    `plt.xlabel('${x}')`,
                    `plt.ylabel('${y}')`,
                    `plt.title('Scatter Plot of "${x}" vs "${y}"')`,
                    "plt.show()",
                ].join("\n");
            }

            case "line":
                return ["df.plot()", "plt.show()"].join("\n");

            default:
                return "";
        }
    }

    /**
     * Returns the columns of the upstream "Import Dataset" node, excluding "id",
     * or [] when no such node (or column list) is present.
     */
    #getDatasetColumns(previousNodes: PreviousNodeData[]): string[] {
        const datasetImporterNode: DatasetImporter | undefined = previousNodes.find(
            (node): node is DatasetImporter & PreviousNodeData => {
                return node.label === "Import Dataset";
            }
        );
        return datasetImporterNode?.datasetColumns?.filter((column) => column !== "id") ?? [];
    }

    /** True when the code contains an assignment of the form `df = pd...`. */
    #dataframeExists(code: string): boolean {
        // Non-global regex: RegExp.test is then stateless (no lastIndex carry-over).
        return /df\s*=\s*pd/.test(code);
    }

    /**
     * Replaces the control's messages with any missing prerequisites: a missing
     * dataframe (which short-circuits the other checks), Matplotlib not imported,
     * and Seaborn not imported for the seaborn-based plot types.
     */
    #checkForRequiredImports(prevNodes: PreviousNodeData[], previousCode: string[]) {
        const display = this.controls.controls;
        display.clearMessages();

        if (!this.#dataframeExists(previousCode.join("\n"))) {
            display.addMessage("No dataframe found");
            return;
        }

        const packageImporterNode = prevNodes.find((node) => node.label === "Import Libraries") as
            | PackageImporter
            | undefined;
        const importedLibraries = new Set(
            (packageImporterNode?.selectedItems ?? []).map((item: string) => item.toLowerCase())
        );

        if (!importedLibraries.has("matplotlib")) {
            display.addMessage("Import Matplotlib");
        }
        const usesSeaborn = this.plotType === "heatmap" || this.plotType === "scatter";
        if (usesSeaborn && !importedLibraries.has("seaborn")) {
            display.addMessage("Import Seaborn");
        }
    }

    /**
     * Rete data worker: extends the upstream code chain with this node's plotting
     * snippet (prefixed by `plt.clf()`) and appends this node to the node chain.
     * Slot 0 of the code chain is replaced with the collected import statements
     * when there are any.
     */
    data(inputs: { data: DataOutput[] }): { output: DataOutput } {
        // `[0]` may be missing when the input array is empty, so guard both levels.
        const upstream = inputs.data?.[0];
        const previousNodes = upstream?.previousNodes ?? [];
        const importStatement: string = loadCollectedImportStatements(previousNodes);
        // Copy so the upstream node's output array is not mutated in place.
        const previousCode = [...(upstream?.code ?? [])];
        if (inputs.data && importStatement) {
            previousCode[0] = importStatement;
        }

        this.#setDfColumns(this.#getDatasetColumns(previousNodes));

        const plotCode = this.#getCurrentCode();
        const currentCode = plotCode
            ? "plt.clf() # Clear any previous plot data\n" + plotCode
            : plotCode;

        this.#checkForRequiredImports(previousNodes, previousCode);

        const output: DataOutput = {
            code: [...previousCode, currentCode],
            // NOTE(review): no output is ever added to this node, so `this.outputs.output`
            // is likely undefined at runtime despite the `!` — confirm intent.
            previousNodes: [...previousNodes, { ...this, connectedPort: this.outputs.output! }],
        };
        this.update("node", this);
        return { output };
    }
}
