import { useEffect } from 'react';
import dagre from '@dagrejs/dagre';
import { type Node, type Edge, useReactFlow, useNodesInitialized, useStore, Position } from '@xyflow/react';

/** dagre rank direction (passed as `rankdir`): T = top, B = bottom, L = left, R = right. */
type Direction = 'TB' | 'BT' | 'LR' | 'RL';

/**
 * Side of a node from which outgoing edges should leave for the given layout
 * direction, i.e. the side facing the following rank.
 */
function getSourceHandlePosition(direction: Direction) {
  const sourceSideByDirection = {
    TB: Position.Bottom,
    BT: Position.Top,
    LR: Position.Right,
    RL: Position.Left,
  };

  return sourceSideByDirection[direction];
}

/**
 * Side of a node where incoming edges should arrive for the given layout
 * direction, i.e. the side facing the preceding rank.
 */
function getTargetHandlePosition(direction: Direction) {
  const targetSideByDirection = {
    TB: Position.Top,
    BT: Position.Bottom,
    LR: Position.Left,
    RL: Position.Right,
  };

  return targetSideByDirection[direction];
}

/**
 * Module-level dagre graph reused by every `dagreLayout` run. Each new edge gets
 * a fresh empty object as its label.
 */
const dagreGraph = new dagre.graphlib.Graph();
dagreGraph.setDefaultEdgeLabel(() => ({}));

/** Options accepted by a layout algorithm. */
interface LayoutAlgorithmOptions {
  /** Rank direction; `dagreLayout` forwards it as dagre's `rankdir`. */
  direction: Direction;
  /** `[nodesep, ranksep]`: gap between nodes in a rank, and gap between ranks. */
  spacing: [number, number];
}

/**
 * Signature of a layout strategy: takes the flow's nodes and edges plus options
 * and returns the nodes and edges to write back into the flow.
 */
type LayoutAlgorithm = (
  nodes: Array<Node>,
  edges: Array<Edge>,
  options: LayoutAlgorithmOptions,
) => { edges: Array<Edge>; nodes: Array<Node> };

/**
 * Lays out `nodes` with dagre and returns copies placed at their computed positions.
 *
 * dagre reports each node's centre as (x, y). Subtracting half of the node's
 * measured size converts that into the node's `position`. Unmeasured nodes are
 * treated as 0×0.
 *
 * The module-level `dagreGraph` is reused across calls, so it is first brought in
 * sync with the input:
 * - nodes that are no longer present are removed;
 * - every edge from the previous run is dropped and the current edges re-added.
 *
 * Without this, an edge deleted from the flow (with both endpoints still present)
 * would keep influencing the ranking forever.
 *
 * @param nodes - Nodes to position; sized from `node.measured`.
 * @param edges - Edges defining the hierarchy. Edges whose source or target is not
 *   in `nodes` are ignored for layout. The array itself is returned unchanged.
 * @param options - `direction` becomes dagre's `rankdir`; `spacing` is `[nodesep, ranksep]`.
 * @returns New node objects with an updated `position`, plus the input `edges` array.
 */
const dagreLayout: LayoutAlgorithm = (nodes, edges, options) => {
  dagreGraph.setGraph({
    rankdir: options.direction,
    nodesep: options.spacing[0],
    ranksep: options.spacing[1],
  });

  // Set lookup keeps the sync steps below linear in the number of nodes/edges.
  const nodeIds = new Set(nodes.map((node) => node.id));

  // setEdge only ever adds, so clear all edges left over from the previous run;
  // the current set is rebuilt below. (The graph is not a multigraph, so v/w
  // uniquely identify an edge.)
  for (const edge of dagreGraph.edges()) {
    dagreGraph.removeEdge(edge.v, edge.w);
  }

  for (const id of dagreGraph.nodes()) {
    if (!nodeIds.has(id)) {
      dagreGraph.removeNode(id);
    }
  }

  for (const node of nodes) {
    dagreGraph.setNode(node.id, {
      width: node.measured?.width ?? 0,
      height: node.measured?.height ?? 0,
    });
  }

  for (const edge of edges) {
    // graphlib's setEdge silently adds a missing endpoint as a label-less node,
    // which would then take part in the layout; skip dangling edges instead.
    if (nodeIds.has(edge.source) && nodeIds.has(edge.target)) {
      dagreGraph.setEdge(edge.source, edge.target);
    }
  }

  dagre.layout(dagreGraph);

  const nextNodes = nodes.map((node) => {
    const { x, y } = dagreGraph.node(node.id);
    const position = {
      x: x - (node.measured?.width ?? 0) / 2,
      y: y - (node.measured?.height ?? 0) / 2,
    };

    return { ...node, position };
  });

  return { nodes: nextNodes, edges };
};

/**
 * Runs the dagre layout over the current flow and writes the result back into React Flow.
 *
 * On each run it:
 * - repositions the nodes;
 * - sets every node's source/target handle side to match `options.direction`;
 * - sets `style.opacity` to 1 on every node and edge. Presumably elements start
 *   hidden until laid out; confirm against the node/edge defaults used by callers.
 *
 * The layout is skipped until all nodes are initialized (`useNodesInitialized`)
 * and while there are no nodes. The store subscription is filtered by
 * `compareElements`, so only changes in node count, measured node sizes, or edge
 * endpoints/handles re-trigger it; moving a node does not.
 *
 * @param options - Layout direction and `[nodesep, ranksep]` spacing. Only the
 *   primitive values are tracked, so passing a fresh object literal on every
 *   render does not cause a re-layout loop.
 */
export function useAutoLayout(options: LayoutAlgorithmOptions) {
  const { setNodes, setEdges } = useReactFlow();
  const nodesInitialized = useNodesInitialized();

  // Depend on the primitive values rather than the `options` object: a new object
  // identity every render would otherwise re-run the effect, whose setNodes call
  // re-renders the caller, creating an endless update loop.
  const {
    direction,
    spacing: [nodeSpacing, rankSpacing],
  } = options;

  const elements = useStore(
    (state) => ({
      nodes: state.nodes,
      edges: state.edges,
    }),
    compareElements,
  );

  useEffect(() => {
    if (!nodesInitialized || elements.nodes.length === 0) {
      return;
    }

    const layoutOptions: LayoutAlgorithmOptions = {
      direction,
      spacing: [nodeSpacing, rankSpacing],
    };

    // Shallow copies so the styling below never mutates objects owned by the store.
    const nodes = elements.nodes.map((node) => ({ ...node }));
    const edges = elements.edges.map((edge) => ({ ...edge }));

    const { nodes: nextNodes, edges: nextEdges } = dagreLayout(nodes, edges, layoutOptions);

    const sourcePosition = getSourceHandlePosition(direction);
    const targetPosition = getTargetHandlePosition(direction);

    for (const node of nextNodes) {
      node.style = { ...node.style, opacity: 1 };
      node.sourcePosition = sourcePosition;
      node.targetPosition = targetPosition;
    }

    // Style the edges that are actually written back, not the layout input.
    for (const edge of nextEdges) {
      edge.style = { ...edge.style, opacity: 1 };
    }

    setNodes(nextNodes);
    setEdges(nextEdges);
  }, [nodesInitialized, elements, direction, nodeSpacing, rankSpacing, setNodes, setEdges]);
}

/** Snapshot of the flow's nodes and edges as selected from the React Flow store. */
interface Elements {
  nodes: Node[];
  edges: Edge[];
}

/**
 * Equality function for the store selector in `useAutoLayout`.
 *
 * Two snapshots count as equal when both their node lists and their edge lists
 * compare equal. Edges are only checked once the nodes already match. An "equal"
 * result keeps the previous selection, so no new layout runs.
 */
function compareElements(xs: Elements, ys: Elements) {
  const { nodes: prevNodes, edges: prevEdges } = xs;
  const { nodes: nextNodes, edges: nextEdges } = ys;

  if (!compareNodes(prevNodes, nextNodes)) {
    return false;
  }

  return compareEdges(prevEdges, nextEdges);
}

/**
 * Decides whether two node lists are equivalent for layout purposes.
 *
 * Only the list length and each node's measured width/height are compared;
 * positions are ignored. Nodes are walked in order. On reaching a node in `xs`
 * that is being resized or dragged, the lists are reported equal without checking
 * the remaining nodes; presumably this avoids re-laying-out mid-interaction.
 * Nodes before it are still compared.
 */
function compareNodes(xs: Array<Node>, ys: Array<Node>) {
  // Adding or removing a node always needs a new layout.
  if (xs.length !== ys.length) return false;

  for (let index = 0; index < xs.length; index += 1) {
    const prev = xs[index];
    const next = ys[index];

    // An empty slot counts as a change.
    if (!prev || !next) return false;

    // Deliberate early exit: do not inspect the rest while the user interacts.
    if (prev.resizing || prev.dragging) return true;

    const widthChanged = prev.measured?.width !== next.measured?.width;
    const heightChanged = prev.measured?.height !== next.measured?.height;
    if (widthChanged || heightChanged) return false;
  }

  return true;
}

/**
 * Decides whether two edge lists are equivalent for layout purposes.
 *
 * The lists must have the same length. Each pair at the same index must agree on
 * `source`, `target`, `sourceHandle` and `targetHandle` (strict equality); any
 * other edge property is ignored. An empty slot counts as a difference.
 */
function compareEdges(xs: Array<Edge>, ys: Array<Edge>) {
  if (xs.length !== ys.length) {
    return false;
  }

  const layoutKeys = ['source', 'target', 'sourceHandle', 'targetHandle'] as const;

  // findIndex (unlike some/every) also visits holes, so empty slots are detected.
  const firstMismatch = xs.findIndex((prev, index) => {
    const next = ys[index];
    return !prev || !next || layoutKeys.some((key) => prev[key] !== next[key]);
  });

  return firstMismatch === -1;
}
