import React, { StrictMode, createContext, useContext, useState, useCallback, useEffect, useRef } from 'react';
import { createRoot } from 'react-dom/client';
import ReactFlow, { Controls, Panel, useNodesState, useEdgesState, addEdge, updateEdge } from 'reactflow';
import { useHotkeys } from 'react-hotkeys-hook';
import { Tooltip } from 'react-tooltip';

import 'reactflow/dist/base.css';

import { ActionCableProvider, ActionCableContext } from './Shared/ActionCable';
import Library from './Controls/Library';
import MagicSetup from './Controls/MagicSetup';
import CodemirrorModal from './Modals/CodemirrorModal';
import EntryNode from './Nodes/EntryNode';
import InvalidNode from './Nodes/InvalidNode';
import CustomCodeNode from './Nodes/CustomCodeNode';
import SplitNode from './Nodes/SplitNode';
import SessionNode from './Nodes/SessionNode';
import ConditionalNode from './Nodes/ConditionalNode';
import ApiNode from './Nodes/ApiNode';
import PromptTextNode from './Nodes/PromptTextNode';
import PromptImageNode from './Nodes/PromptImageNode';
import ExtractNode from './Nodes/ExtractNode';
import TakeNode from './Nodes/TakeNode';
import CollectNode from './Nodes/CollectNode';

// Node type name -> component map shared by every graph editor on the page.
// The turbo:load handler copies this per form, adds graph-specific types (e.g. `session`),
// and replaces some entries with wrappers that inject values from `form.dataset`.
// React Flow falls back to the `default` type when a node's `type` has no registered
// component, so unrecognized nodes presumably render as InvalidNode.
const STANDARD_NODE_TYPES = {
  default: InvalidNode,
  entry: EntryNode,
  api: ApiNode,
  prompt_text: PromptTextNode,
  prompt_image: PromptImageNode,
  conditional: ConditionalNode,
  split: SplitNode,
  extract: ExtractNode,
  take: TakeNode,
  collect: CollectNode,
  custom_code: CustomCodeNode,
};

// [x, y] grid step in flow units, used together with `snapToGrid` on the ReactFlow canvas.
const SNAP_GRID = [10, 10];

// Properties merged into every node the editor creates (initial load and drag-and-drop).
// `dragHandle` is React Flow's per-node selector that limits dragging to matching elements;
// presumably node components mark their header with `data-drag-handle` (confirm in ./Nodes).
// serializeGraph strips `dragHandle` again before the graph is written to the form.
export const NODE_DEFAULTS = { dragHandle: '[data-drag-handle]' };

/**
 * Build the JSON-safe graph payload that is written to the form's hidden graph input.
 *
 * - Drops `viewport` and the runtime-only node/edge properties React Flow adds
 *   (measured size, absolute position, selection/drag state, className, dragHandle).
 * - Filters out no-op NodeChange/EdgeChange entries (`select`, `position`, `dimensions`)
 *   that onNodesChange/onEdgesChange append to the arrays.
 *
 * The argument is not modified: a shallow copy is returned, with the original key order
 * kept so repeated serializations produce identical JSON (the dirty-form check compares it).
 *
 * @param {{nodes?: object[], edges?: object[], viewport?: object}|null|undefined} graph
 *   usually the result of `instance.toObject()`; missing `nodes`/`edges` are treated as empty.
 * @returns {{nodes: object[], edges: object[]}} the cleaned graph (other top-level keys
 *   except `viewport` are carried over unchanged).
 */
const serializeGraph = (graph) => {
  // spreading null/undefined yields {}, so a missing graph serializes as an empty one
  const result = { ...graph };
  delete result.viewport;

  // remove non-serialized properties
  result.nodes = (result.nodes ?? []).map(({ width, height, positionAbsolute, selected, selectable, dragging, className, dragHandle, ...node }) => node); // prettier-ignore
  result.edges = (result.edges ?? []).map(({ selected, ...edge }) => edge);

  // remove noop NodeChange/EdgeChange events
  result.nodes = result.nodes.filter((x) => !['select', 'position', 'dimensions'].includes(x.type));
  result.edges = result.edges.filter((x) => !['select'].includes(x.type));

  return result;
};

/**
 * Store `graph` as JSON in the hidden form input and fire a `change` event on that input.
 *
 * Does nothing unless both `input` and `graph` are truthy. `caller` is only a label: when it is
 * 'ajax:before' the graph is logged to the console for debugging.
 * NOTE(review): `new Event('change')` is created without `bubbles`, so only listeners on the
 * input itself see it — confirm that is what the dirty-form controller expects.
 */
const syncGraphInput = (input, graph, caller) => {
  if (input && graph) {
    if (caller === 'ajax:before') console.log({ caller, graph });

    const json = JSON.stringify(graph);
    input.value = json;
    input.dispatchEvent(new Event('change'));
  }
};

export const ConnectedHandleIdsContext = createContext();
const ConnectedHandleIdsProvider = ({ children, connectedHandleIds }) => {
  return <ConnectedHandleIdsContext.Provider value={connectedHandleIds}>{children}</ConnectedHandleIdsContext.Provider>;
};

function App({ form, nodeTypes, initialNodes, initialEdges }) {
  const [instance, setInstance] = useState(null);
  const [openModal, setOpenModal] = useState(false);
  const [connectedHandleIds, setConnectedHandleIds] = useState({});
  const edgeUpdateSuccessful = useRef(true);

  const [nodes, setNodes, setOnNodesChange] = useNodesState(initialNodes);
  const [edges, setEdges, setOnEdgesChange] = useEdgesState(initialEdges);

  const actioncable = useContext(ActionCableContext);

  const graphInput = form.querySelector('input[data-field="graph"]');

  useEffect(() => {
    const idsWithConnections = {};

    edges.forEach((edge) => {
      // if the edge has a sourceHandle or a targetHandle, or both, set those values in connectedHandleIds to `true`
      if (edge.sourceHandle) idsWithConnections[edge.sourceHandle] = true;
      if (edge.targetHandle) idsWithConnections[edge.targetHandle] = true;
      if (edge.sourceHandle === null && edge.source?.startsWith('entry:')) idsWithConnections['entry'] = true;
    });

    setConnectedHandleIds(idsWithConnections);
  }, [edges]);

  useEffect(() => {
    if (form?.dataset?.remote) {
      form.addEventListener('ajax:before', () => {
        if (instance) syncGraphInput(graphInput, serializeGraph(instance.toObject()), 'ajax:before');
        form.dirty_form_controller.reset_state();
      });
    }

    if (form.dataset.blueprintId) {
      const channel = actioncable.subscriptions.create(
        { channel: 'BlueprintsChannel', id: form.dataset.blueprintId },
        {
          received: (data) => {
            if (data.command == 'replace_graph') {
              console.log('replacing graph', data.graph);
              instance.setNodes(data.graph.nodes);
              instance.setEdges(data.graph.edges);
              if (data.fitView !== false) setTimeout(() => instance.fitView({ duration: 300 }), 300);
              if (form.dirty_form_controller) setTimeout(() => form.dirty_form_controller.reset_state(), 300);
            }
          },
        },
      );
      return () => channel.unsubscribe();
    }
  }, [form, instance]);

  useHotkeys(
    'meta+a',
    () => {
      instance.setNodes(instance.getNodes().map((n) => (n.type === 'entry' ? n : { ...n, selected: true })));
      instance.setEdges(instance.getEdges().map((n) => ({ ...n, selected: true })));
    },
    { preventDefault: true },
  );

  const onEdgeUpdateStart = useCallback(() => {
    edgeUpdateSuccessful.current = false;
  }, []);

  const onEdgeUpdate = useCallback(
    (eds, connection) => {
      edgeUpdateSuccessful.current = true;
      setEdges((els) => {
        const edges = updateEdge(eds, connection, els);
        if (instance) {
          const graph = instance.toObject();
          graph.edges = edges;
          syncGraphInput(graphInput, serializeGraph(graph), 'onEdgeUpdate');
        }
        return edges;
      });
    },
    [instance],
  );

  const onEdgeUpdateEnd = useCallback(
    (_, edge) => {
      if (!edgeUpdateSuccessful.current) {
        setEdges((eds) => eds.filter((e) => e.id !== edge.id));
        if (instance) {
          const graph = instance.toObject();
          graph.edges = graph.edges.filter((e) => e.id !== edge.id);
          syncGraphInput(graphInput, serializeGraph(graph), 'onEdgeUpdateEnd');
        }
      }
      edgeUpdateSuccessful.current = true;
    },
    [instance],
  );

  const onInit = useCallback((instance) => {
    setInstance(instance);

    // reset dirty form state after reserializing, so order is consistent
    if (graphInput && form.dirty_form_controller) {
      if (instance) syncGraphInput(graphInput, serializeGraph(instance.toObject()), 'onInit');
      form.dirty_form_controller.reset_state();
    }

    // fit view to focused node
    const { hash } = window.location;
    if (hash) {
      const id = hash.split('focus=')[1];
      const node = instance.getNode(id);
      if (node) {
        instance.setNodes(instance.getNodes().map((n) => (n.id == node.id ? { ...n, selected: true } : n)));
        instance.fitView({ nodes: [node], duration: 300 });
      }
    }
  }, []);

  const onDragOver = useCallback((event) => {
    event.preventDefault();
    event.dataTransfer.dropEffect = 'move';
  }, []);

  const onDrop = useCallback(
    (event) => {
      event.preventDefault();

      // validate dropped element
      const type = event.dataTransfer.getData('application/reactflow');
      if (typeof type === 'undefined' || !type) return;

      const position = instance.screenToFlowPosition({ x: event.clientX, y: event.clientY });

      setNodes((nds) => [
        ...nds.map((node) => ({ ...node, selected: false })),
        { id: `node:${crypto.randomUUID()}`, data: {}, type, position, selected: true, ...NODE_DEFAULTS },
      ]);
    },
    [setNodes, instance],
  );

  const onNodeClick = useCallback(
    (event, node) => {
      const target = event.target.closest('[data-open-modal]') || event.target;
      if (target.dataset.openModal) setOpenModal({ node, mode: target.dataset.openModal, field: target.dataset.field });
    },
    [setOpenModal],
  );

  const onModalClose = useCallback(() => {
    // close modal
    setOpenModal(false);

    // force sync
    if (instance) syncGraphInput(graphInput, serializeGraph(instance.toObject()), 'onModalClose');
  }, [instance, setOpenModal]);

  // TODO: debounce setNodes
  const onModalChange = useCallback(
    (data) => {
      if (!openModal.node) return;

      setNodes((nds) =>
        nds.map((node) => {
          if (node.id === openModal.node.id) node.data = { ...node.data, ...data };
          return node;
        }),
      );

      if (instance) syncGraphInput(graphInput, serializeGraph(instance.toObject()), 'onModalChange');
    },
    [instance, openModal, setNodes],
  );

  const onNodesChange = useCallback(
    (nodes) => {
      setOnNodesChange(nodes);

      if (instance) {
        const graph = instance.toObject();
        graph.nodes = [...graph.nodes, ...nodes];
        syncGraphInput(graphInput, serializeGraph(graph), 'onNodesChange');
      }
    },
    [instance, setOnNodesChange],
  );

  const onConnect = useCallback(
    (connection) => {
      setEdges((eds) => {
        const edges = addEdge({ ...connection, id: `edge:${crypto.randomUUID()}` }, eds);
        const graph = instance.toObject();
        graph.edges = edges;
        if (instance) syncGraphInput(graphInput, serializeGraph(graph), 'onConnect');
        return edges;
      });
    },
    [instance, setEdges],
  );

  const onEdgesChange = useCallback(
    (edges) => {
      setOnEdgesChange(edges);

      if (instance) {
        const graph = instance.toObject();
        graph.edges = [...graph.edges, ...edges];
        syncGraphInput(graphInput, serializeGraph(graph), 'onEdgesChange');
      }
    },
    [instance],
  );

  const validateConnection = useCallback(
    (connection) => {
      // prevent connecting onto self
      if (connection.source == connection.target) return false;

      return true;
    },
    [instance],
  );

  return (
    <ConnectedHandleIdsProvider connectedHandleIds={connectedHandleIds}>
      <ReactFlow
        onInit={onInit}
        nodeTypes={nodeTypes}
        nodes={nodes}
        onNodesChange={onNodesChange}
        onNodeClick={onNodeClick}
        edges={edges}
        onEdgesChange={onEdgesChange}
        onEdgeUpdate={onEdgeUpdate}
        onEdgeUpdateStart={onEdgeUpdateStart}
        onEdgeUpdateEnd={onEdgeUpdateEnd}
        onConnect={onConnect}
        onDrop={onDrop}
        onDragOver={onDragOver}
        isValidConnection={validateConnection}
        selectNodesOnDrag={false}
        snapToGrid
        snapGrid={SNAP_GRID}
        minZoom={0.25}
        maxZoom={2.0}
        fitView
        fitViewOptions={{ maxZoom: 1.0 }}
        deleteKeyCode={['Delete', 'Backspace']}
        proOptions={{ hideAttribution: true }}
      >
        {!!openModal && <div className="modal-glass" onClick={() => setOpenModal(false)} />}

        <Panel position="top-left">
          <Library nodeTypes={nodeTypes} />
          <Tooltip id="node-tooltip" />
        </Panel>

        {form.dataset.graph === 'session_manager' && (
          <Panel position="top-left" style={{ width: '100%' }}>
            <MagicSetup
              blueprintId={form.dataset.blueprintId}
              openapiSpec={form.dataset.openapiSpec}
              sampleTeamIsActive={form.dataset.sampleTeamIsActive}
            />
          </Panel>
        )}

        <Panel position="top-right" style={{ zIndex: 11 }}>
          {!!openModal && <CodemirrorModal details={openModal} onClose={onModalClose} onChange={onModalChange} />}
        </Panel>
        <Controls showInteractive={false} />
      </ReactFlow>
    </ConnectedHandleIdsProvider>
  );
}

document.addEventListener('turbo:load', () => {
  document.querySelectorAll('form[data-graph]').forEach((form) => {
    if (form.dataset.initialized) document.querySelector(`#${form.dataset.initialized}`).remove();

    const GRAPH_TYPE = form.dataset.graph;
    const GRAPH = JSON.parse(form.querySelector('input[data-field="graph"]').value || '{}');

    // TODO: validate GRAPH
    let nodes = GRAPH.nodes || [];
    let edges = GRAPH.edges || [];

    // preprocess nodes
    nodes = nodes.map((node) => {
      // move source to data
      node.data.source = node.source;

      // add defaults
      node = { ...node, ...NODE_DEFAULTS };

      return node;
    });

    // ensure entry node exists and is valid
    if (!nodes.find((x) => x.type === 'entry')) nodes.push({ id: `entry:${crypto.randomUUID()}`, type: 'entry', position: { x: 0, y: 0 } });
    nodes.find((x) => x.type === 'entry').selectable = false;

    let nodeTypes = { ...STANDARD_NODE_TYPES };

    // customize nodeTypes per GRAPH_TYPE
    if (GRAPH_TYPE === 'session_manager') {
      nodeTypes['session'] = SessionNode;
    }

    console.log({ GRAPH_TYPE, GRAPH });

    // customize nodeTypes
    nodeTypes['entry'] = (props) => <EntryNode {...props} {...form.dataset} />;
    nodeTypes['api'] = (props) => <ApiNode {...props} api_host={form.dataset.apiHost} />;
    nodeTypes['prompt_text'] = (props) => <PromptTextNode {...props} playground_stream_url={form.dataset.playgroundStreamUrl} />;
    nodeTypes['prompt_image'] = (props) => <PromptImageNode {...props} playground_stream_url={form.dataset.playgroundStreamUrl} />;

    const container = document.createElement('div');
    container.classList = 'graph';
    container.id = `graph_${crypto.randomUUID()}`;

    createRoot(container).render(
      <StrictMode>
        <ActionCableProvider>
          <App form={form} nodeTypes={nodeTypes} initialNodes={nodes} initialEdges={edges} />
        </ActionCableProvider>
      </StrictMode>,
    );

    form.insertAdjacentElement('afterend', container);
    form.dataset.initialized = container.id;
  });
});
