import {useRef, useEffect, FC, useState} from "react";
import ForceGraph2D from "react-force-graph-2d";
import {ForceGraph2DProps} from "react-force-graph";
import { Box } from "@mui/material";

// Define node and link types
/**
 * A node of the graph as it is read and drawn by the Graph component in this file.
 */
interface GraphNode {
    // Identifier; link endpoints are compared by this value when walking links on node click.
    id: string;
    // Text drawn as the node's label; also used to measure the node's box.
    name: string;
    // Grouping key: positionNodes starts a new column whenever this value changes
    // between consecutive nodes.
    // NOTE(review): positionNodes annotates `group` as string while it is a number here — confirm.
    group: number;
    // Horizontal canvas coordinate; overwritten by positionNodes.
    x: number;
    // Vertical canvas coordinate; overwritten by positionNodes.
    y: number;
    // Optional color; not read anywhere in the visible part of this file.
    color?: string;
    // Optional icon file name, resolved relative to "../../assets/icons/" by cacheIcons.
    icon?: string;
}

/**
 * A directed connection between two nodes, as read by the click handler and the link painter.
 */
interface Link {
    // Node the link starts from; the click traversal follows links whose source matches the current node.
    source: GraphNode;
    // Node the link points to; the link curve ends at this node's left edge.
    target: GraphNode;
    // Not read anywhere in the visible part of this file.
    value: string;
}

const Graph: FC<ForceGraph2DProps> = ({graphData}) => {
    const graphRef = useRef<any>(null);
    const [highlightLinks, setHighlightLinks] = useState(new Set<Link>());
    const [iconCache, setIconCache] = useState(new Map<string, HTMLImageElement>());
    const imageSize = 5;

    // positionNodes assigns grid coordinates to every node: consecutive nodes that share a
    // group are stacked top-to-bottom in one column, and each change of group between
    // neighbouring nodes starts the next column from the top.
    // The caller must ensure graphData.nodes is non-empty (the first node is read unconditionally).
    const positionNodes = () => {
        const colSpacing = 80;
        const rowSpacing = 25;

        let currentGroup = graphData.nodes[0].group;
        let column = 0;
        let row = 0;

        graphData.nodes.forEach((node: { group: string; y: number; x: number; }) => {
            const startsNewColumn = currentGroup != node.group;
            if (startsNewColumn) {
                column += 1;
                row = 0;
                currentGroup = node.group;
            }

            node.x = column * colSpacing;
            node.y = row * rowSpacing;
            row += 1;
        });
    }

    // cacheIcons loads the icon image of every node that names one and stores each image
    // in the iconCache state once it has finished loading.
    //
    // Each loaded image is published through a state update with a fresh Map, so the
    // component re-renders and the canvas callbacks see the icon. (Previously an empty Map
    // was put into state and mutated later from onload, which React cannot observe.)
    const cacheIcons = () => {
        // Icon names requested during this call; keeps us from creating one Image per node
        // when many nodes share the same icon.
        const requested = new Set<string>();

        graphData.nodes.forEach((node: { icon: string; }) => {
            const iconName = node.icon;
            if (!iconName || requested.has(iconName) || iconCache.has(iconName)) {
                return;
            }
            requested.add(iconName);

            const icon = new Image(imageSize, imageSize);
            // Attach the handler before assigning src so the load event cannot be missed.
            icon.onload = () => {
                // Functional update: merge into whatever the cache holds when the load completes.
                setIconCache((previous) => new Map(previous).set(iconName, icon));
            };
            icon.src = require("../../assets/icons/" + iconName);
        });
    };

    // Runs whenever graphData changes: lays the nodes out in fixed positions, loads their
    // icons, frames the canvas and removes the listed simulation forces.
    useEffect(() => {
        const graph = graphRef.current;
        if (!graph || graphData.nodes.length == 0) {
            return;
        }

        positionNodes();
        cacheIcons();

        // Position canvas
        graph.zoom(4);
        graph.centerAt(15.5, 0);

        // Turn off forces by name.
        // NOTE(review): only these five names are cleared; any other force the library
        // registers by default (e.g. a centering force) stays active — confirm this is intended.
        ["charge", "link", "x", "y", "collision"].forEach((forceName) => {
            graph.d3Force(forceName, null);
        });
    }, [graphData]);

    // getFontSize returns the label font size for the given scale: a base size of 12
    // divided by globalScale.
    const getFontSize = (globalScale: number) => {
        const baseFontSize = 12;
        return baseFontSize / globalScale;
    }

    // getNodeSize returns the width and height of the box drawn around a node label:
    // the measured text width (resp. the font size) plus a padding of 8 on each side.
    //
    // The label is measured with the same font the node painter uses for the text.
    // Without setting it here, measureText would use whatever font happens to be
    // current on the context, which differs between callers and between frames.
    const getNodeSize = (label: string, globalScale: number, ctx: CanvasRenderingContext2D) => {
        const padding = 8;
        const fontSize = getFontSize(globalScale);

        // save/restore keeps the temporary font from leaking into the caller's context state.
        ctx.save();
        ctx.font = `${fontSize}px Sans-Serif`;
        const textWidth = ctx.measureText(label).width;
        ctx.restore();

        return {width: textWidth + padding * 2, height: fontSize + padding * 2};
    }

    // traceRoundedRect builds (but does not fill or stroke) a rounded-rectangle path whose
    // top-left corner is at (left, top).
    const traceRoundedRect = (
        ctx: CanvasRenderingContext2D,
        left: number,
        top: number,
        boxWidth: number,
        boxHeight: number,
        radius: number,
    ): void => {
        const right = left + boxWidth;
        const bottom = top + boxHeight;

        ctx.beginPath();
        ctx.moveTo(left + radius, top);
        ctx.lineTo(right - radius, top);
        ctx.quadraticCurveTo(right, top, right, top + radius);
        ctx.lineTo(right, bottom - radius);
        ctx.quadraticCurveTo(right, bottom, right - radius, bottom);
        ctx.lineTo(left + radius, bottom);
        ctx.quadraticCurveTo(left, bottom, left, bottom - radius);
        ctx.lineTo(left, top + radius);
        ctx.quadraticCurveTo(left, top, left + radius, top);
        ctx.closePath();
    };

    return (
        <Box sx={{display: "flex", alignItems: "center", justifyContent: "center"}}>
            <ForceGraph2D
                ref={graphRef}
                graphData={graphData}
                width={900}
                height={500}
                nodeRelSize={15}
                enableNodeDrag={false}

                // Breadth-first walk from the clicked node along outgoing links; every link
                // reached this way is highlighted.
                onNodeClick={(node) => {
                    if (!node) return;

                    // Serves both as the "already followed" marker and as the result set.
                    const reachedLinks = new Set<Link>();
                    const frontier: GraphNode[] = [node];

                    while (frontier.length > 0) {
                        const current: GraphNode = frontier.shift() as GraphNode;

                        graphData.links.forEach((link: Link) => {
                            if (link.source.id != current.id || reachedLinks.has(link)) {
                                return;
                            }
                            reachedLinks.add(link);
                            frontier.push(link.target);
                        });
                    }

                    setHighlightLinks(reachedLinks);
                }}

                // Zero does not work as a value
                linkWidth={() => 0.000001}
                linkCanvasObjectMode={() => "before"}

                // Draws each link as a slightly curved line running from the right edge of the
                // source box to the left edge of the target box (instead of center to center).
                linkCanvasObject={(link: Link, ctx, globalScale) => {
                    const from = link.source as GraphNode;
                    const to = link.target as GraphNode;
                    if (!(from && to)) return;

                    const fromBox = getNodeSize(from.name, globalScale, ctx);
                    const toBox = getNodeSize(to.name, globalScale, ctx);

                    // The control point is the midpoint, nudged by a fraction of the
                    // endpoints' offsets so the line bends slightly.
                    const curveOffset = 15;
                    const dx = to.x - from.x;
                    const dy = to.y - from.y;
                    const controlX = (from.x + to.x) / 2 + dy / curveOffset;
                    // The vertical nudge flips sign when the target lies below the source.
                    const controlY = (from.y + to.y) / 2 + (dy > 0 ? -dx : dx) / curveOffset;

                    ctx.beginPath();
                    ctx.moveTo(from.x + fromBox.width / 2, from.y);
                    ctx.quadraticCurveTo(controlX, controlY, to.x - toBox.width / 2, to.y);
                    ctx.lineWidth = 0.5;
                    if (highlightLinks.has(link)) {
                        ctx.strokeStyle = "#000";
                    } else {
                        ctx.strokeStyle = "rgba(200,200,200,0.8)";
                    }
                    ctx.stroke();
                }}

                // Draws each node as a white rounded box with a drop shadow, a thin black
                // border, an optional icon at the left and the centered label.
                nodeCanvasObjectMode={() => "replace"}
                nodeCanvasObject={(node: GraphNode, ctx: CanvasRenderingContext2D, globalScale: number): void => {
                    const text = node.name;
                    const box = getNodeSize(text, globalScale, ctx);
                    const cornerRadius = 2;

                    // Shadowed fill; the saved state is restored afterwards so the border,
                    // icon and text are drawn without the shadow.
                    ctx.save();
                    ctx.shadowOffsetX = 2;
                    ctx.shadowOffsetY = 2;
                    ctx.shadowBlur = 4;
                    ctx.shadowColor = "rgba(0,0,0,0.5)";
                    ctx.fillStyle = "#FFFFFF";
                    traceRoundedRect(
                        ctx,
                        node.x - box.width / 2,
                        node.y - box.height / 2,
                        box.width,
                        box.height,
                        cornerRadius,
                    );
                    ctx.fill();
                    ctx.restore();

                    // Border: strokes the path traced above, which is still current.
                    ctx.lineWidth = 0.5;
                    ctx.strokeStyle = "black";
                    ctx.stroke();

                    // Icon, if one has been cached for this node, just inside the left edge.
                    const icon = iconCache.get(node.icon!);
                    if (icon) {
                        ctx.drawImage(
                            icon,
                            1.5 + node.x - box.width / 2,
                            node.y - icon.height / 2,
                            imageSize,
                            imageSize,
                        );
                    }

                    // Label, centered on the node's coordinates.
                    ctx.fillStyle = "black";
                    ctx.textBaseline = "middle";
                    ctx.textAlign = "center";
                    ctx.font = `${getFontSize(globalScale)}px Sans-Serif`;
                    ctx.fillText(text, node.x, node.y);
                }}
            />
        </Box>
    );
};

export default Graph;
