// src/components/SQLLineageFlowVisualizer.js
import React, { useCallback, useLayoutEffect } from 'react';
import ReactFlow, {
  MiniMap,
  Controls,
  Background,
  useNodesState,
  useEdgesState,
  MarkerType,
} from 'reactflow';
import 'reactflow/dist/style.css';

// Import our custom nodes
import TableNode from './custom-nodes/TableNode';
import CTENode from './custom-nodes/CTENode';
import OutputNode from './custom-nodes/OutputNode';
import SubqueryNode from './custom-nodes/SubqueryNode';

// Import our custom edge
import SelfLoopEdge from './custom-edges/SelfLoopEdge';

// Registry mapping each node `type` string used by this component
// ('tableNode', 'cteNode', 'outputNode', 'subqueryNode') to its custom
// renderer. Declared at module scope so the same object is handed to
// <ReactFlow> on every render.
const nodeTypes = {
  tableNode: TableNode,
  cteNode: CTENode,
  outputNode: OutputNode,
  subqueryNode: SubqueryNode,
};

// Registry of custom edge renderers. 'selfLoop' is the type given to edges
// whose source and target are the same node (recursive CTEs/subqueries).
const edgeTypes = {
  selfLoop: SelfLoopEdge,
};

/**
 * Places one subquery in the left-to-right layout by storing its column index
 * in `levels['subquery-<id>']`.
 *
 * The column is one past the deepest dependency among the subquery's
 * referenced CTEs (matched by name, ignoring case) and referenced subqueries;
 * source tables occupy column 0 and therefore never push it further right.
 * A subquery listing itself is skipped. If the id is reached again while it is
 * still on `recursionPath` (a longer cycle), it only receives a provisional
 * column - one past the referenced subqueries that are already placed - and
 * only when it has no column yet.
 *
 * @param {*} subqueryId - id of the subquery to place.
 * @param {Array<object>} allSubqueries - every subquery (`id`, optional
 *   `referencedCTEs` / `referencedSubqueries`).
 * @param {Array<object>} allCtes - every CTE; used to resolve referenced names
 *   and handed on to calculateCTELevel.
 * @param {Object<string, number>} levels - shared level map, mutated in place.
 * @param {Set} [visited] - ids already expanded during this traversal.
 * @param {Set} [recursionPath] - ids on the current chain of recursive calls.
 */
const calculateSubqueryLevel = (subqueryId, allSubqueries, allCtes, levels, visited = new Set(), recursionPath = new Set()) => {
  const keyFor = (id) => `subquery-${id}`;
  const ownKey = keyFor(subqueryId);

  // Re-entered while still being computed: break the cycle with a provisional
  // level instead of recursing again.
  if (recursionPath.has(subqueryId)) {
    const cyclic = allSubqueries.find(sq => sq.id === subqueryId);
    let placedPeak = 0;
    if (cyclic && cyclic.referencedSubqueries) {
      cyclic.referencedSubqueries.forEach(ref => {
        const refLevel = levels[keyFor(ref)];
        if (ref === subqueryId || refLevel === undefined) return;
        placedPeak = Math.max(placedPeak, refLevel);
      });
    }
    if (levels[ownKey] === undefined) {
      levels[ownKey] = placedPeak + 1;
    }
    return;
  }

  // Expand each subquery at most once per traversal.
  if (visited.has(subqueryId)) return;
  visited.add(subqueryId);
  recursionPath.add(subqueryId);

  const current = allSubqueries.find(sq => sq.id === subqueryId);
  if (!current) {
    recursionPath.delete(subqueryId);
    return;
  }

  // Tables are column 0, so the zero baseline already accounts for them.
  let peak = 0;

  const cteRefs = current.referencedCTEs;
  if (cteRefs && cteRefs.length > 0) {
    cteRefs.forEach(refCte => {
      // Use the CTE's declared spelling so the key matches the one it is stored under.
      const declared = allCtes.find(c => c.name.toLowerCase() === refCte.toLowerCase());
      const resolvedName = declared ? declared.name : refCte;
      const cteKey = `cte-${resolvedName}`;
      if (levels[cteKey] === undefined) {
        calculateCTELevel(resolvedName, allCtes, allSubqueries, levels, new Set(), new Set());
      }
      peak = Math.max(peak, levels[cteKey] || 0);
    });
  }

  const subqueryRefs = current.referencedSubqueries;
  if (subqueryRefs && subqueryRefs.length > 0) {
    subqueryRefs.forEach(ref => {
      // A self-reference is rendered as a loop, not treated as a dependency.
      if (ref === subqueryId) return;
      if (levels[keyFor(ref)] === undefined) {
        calculateSubqueryLevel(ref, allSubqueries, allCtes, levels, visited, new Set(recursionPath));
      }
      peak = Math.max(peak, levels[keyFor(ref)] || 0);
    });
  }

  levels[ownKey] = peak + 1;
  recursionPath.delete(subqueryId);
};

/**
 * Computes the layout level (column index) of a CTE and records it in
 * `levels` under the key `cte-<name>`.
 *
 * A CTE is placed one level to the right of its deepest dependency. Source
 * tables are level 0; referenced CTEs and subqueries contribute their own
 * levels, which are computed on demand. Self-references (recursive CTEs) are
 * ignored here because they are drawn as self-loop edges; longer CTE cycles
 * are cut using `recursionPath`.
 *
 * @param {string} cteName - Declared name of the CTE to place.
 * @param {Array<object>} allCtes - Every CTE (`name`, and optionally
 *   `sourceTables`, `referencedCTEs`, `referencedSubqueries`).
 * @param {Array<object>} [allSubqueries] - Every subquery; forwarded to
 *   calculateSubqueryLevel. Treated as empty when undefined.
 * @param {Object<string, number>} levels - Level map shared by the whole
 *   layout; mutated in place.
 * @param {Set<string>} [visited] - CTE names already expanded in this traversal.
 * @param {Set<string>} [recursionPath] - CTE names on the current call stack.
 */
const calculateCTELevel = (cteName, allCtes, allSubqueries, levels, visited = new Set(), recursionPath = new Set()) => {
  const ownKey = `cte-${cteName}`;
  const subqueryList = allSubqueries || [];

  // Map a referenced name to the declared CTE spelling; references are
  // matched case-insensitively, as elsewhere in this file.
  const resolveCteName = (name) => {
    const match = allCtes.find(c => c.name.toLowerCase() === name.toLowerCase());
    return match ? match.name : name;
  };

  if (recursionPath.has(cteName)) {
    // Cycle between CTEs: assign a provisional level one past the
    // dependencies that are already placed, without recursing further.
    let maxNonRecursiveLevel = 0;
    const cte = allCtes.find(c => c.name === cteName);
    if (cte && cte.referencedCTEs) {
      cte.referencedCTEs.forEach(refCte => {
        const refName = resolveCteName(refCte);
        const refLevel = levels[`cte-${refName}`];
        if (refName !== cteName && refLevel !== undefined) {
          maxNonRecursiveLevel = Math.max(maxNonRecursiveLevel, refLevel);
        }
      });
    }
    if (levels[ownKey] === undefined) {
      levels[ownKey] = maxNonRecursiveLevel + 1;
    }
    return;
  }

  // Each CTE is expanded at most once per traversal.
  if (visited.has(cteName)) return;
  visited.add(cteName);
  recursionPath.add(cteName);

  const cte = allCtes.find(c => c.name === cteName);
  if (!cte) {
    recursionPath.delete(cteName);
    return;
  }

  // Source tables sit at level 0 and so never raise this baseline; a CTE is
  // therefore always at least level 1.
  let maxDependencyLevel = 0;

  (cte.referencedCTEs || []).forEach(refCte => {
    const refCteName = resolveCteName(refCte);
    // Compare the resolved name so a differently-cased self-reference is
    // still treated as recursion rather than as a dependency on itself.
    if (refCteName === cteName) return;

    if (levels[`cte-${refCteName}`] === undefined) {
      calculateCTELevel(refCteName, allCtes, allSubqueries, levels, visited, new Set(recursionPath));
    }
    maxDependencyLevel = Math.max(maxDependencyLevel, levels[`cte-${refCteName}`] || 0);
  });

  const referencedSubqueries = cte.referencedSubqueries || [];
  if (referencedSubqueries.length > 0) {
    // calculateSubqueryLevel is started with fresh cycle-tracking sets, and it
    // in turn calls back into this function with fresh sets. A subquery that
    // references this CTE back would therefore re-enter it endlessly (stack
    // overflow). Publishing a provisional level first makes that subquery see
    // a defined level for this CTE and stop recursing.
    levels[ownKey] = maxDependencyLevel + 1;

    referencedSubqueries.forEach(refSq => {
      if (levels[`subquery-${refSq}`] === undefined) {
        calculateSubqueryLevel(refSq, subqueryList, allCtes, levels, new Set(), new Set());
      }
      maxDependencyLevel = Math.max(maxDependencyLevel, levels[`subquery-${refSq}`] || 0);
    });
  }

  // Final level: one past the deepest dependency.
  levels[ownKey] = maxDependencyLevel + 1;
  recursionPath.delete(cteName);
};

const SQLLineageFlowVisualizer = ({ data }) => {
  const { sourceTables, ctes, subqueries, outputColumns, detectedDialect } = data;
  const [nodes, setNodes, onNodesChange] = useNodesState([]);
  const [edges, setEdges, onEdgesChange] = useEdgesState([]);

  // Function to build the graph layout
  const buildGraph = useCallback(() => {
    const newNodes = [];
    const newEdges = [];
    
    // Calculate levels for each node for proper left-to-right layout
    const levels = {};

    const compatibleFinalQuerySources = Array.isArray(data.finalQuerySources) 
      ? data.finalQuerySources.map(source => {
          // Check if source is already in the new format
          if (typeof source === 'object' && source.name && source.type) {
            return source;
          }
          // Convert old string format to new object format
          const sourceStr = String(source);
          // Determine type based on presence in collections
          let type = 'table';
          if (ctes.some(cte => cte.name === sourceStr)) {
            type = 'cte';
          }
          return { name: sourceStr, type };
        })
      : [];
    
    // Source tables are level 0
    sourceTables.forEach(table => {
      levels[`table-${table}`] = 0;
    });
    
    // Calculate CTE and subquery levels based on dependencies
    // First pass: calculate subquery levels
    if (subqueries) {
      subqueries.forEach(subquery => {
        if (levels[`subquery-${subquery.id}`] === undefined) {
          calculateSubqueryLevel(subquery.id, subqueries, ctes, levels);
        }
      });
    }
    
    // Second pass: calculate CTE levels taking subqueries into account
    ctes.forEach(cte => {
      if (levels[`cte-${cte.name}`] === undefined) {
        calculateCTELevel(cte.name, ctes, subqueries, levels);
      }
    });
    
    // After levels are calculated, we may need another pass to resolve
    // interdependencies between CTEs and subqueries
    let changed = true;
    let safetyCounter = 0;
    const MAX_ITERATIONS = 10;  // Prevent infinite loops
    
    while (changed && safetyCounter < MAX_ITERATIONS) {
      changed = false;
      safetyCounter++;
      
      // Check if any CTE needs to be moved up due to subquery references
      ctes.forEach(cte => {
        if (cte.referencedSubqueries && cte.referencedSubqueries.length > 0) {
          let maxRefLevel = 0;
          cte.referencedSubqueries.forEach(refSq => {
            maxRefLevel = Math.max(maxRefLevel, levels[`subquery-${refSq}`] || 0);
          });
          
          if (levels[`cte-${cte.name}`] <= maxRefLevel) {
            const newLevel = maxRefLevel + 1;
            if (newLevel !== levels[`cte-${cte.name}`]) {
              levels[`cte-${cte.name}`] = newLevel;
              changed = true;
            }
          }
        }
      });
      
      // Check if any subquery needs to be moved up due to CTE references
      if (subqueries) {
        subqueries.forEach(subquery => {
          if (subquery.referencedCTEs && subquery.referencedCTEs.length > 0) {
            let maxRefLevel = 0;
            subquery.referencedCTEs.forEach(refCte => {
              maxRefLevel = Math.max(maxRefLevel, levels[`cte-${refCte}`] || 0);
            });
            
            if (levels[`subquery-${subquery.id}`] <= maxRefLevel) {
              const newLevel = maxRefLevel + 1;
              if (newLevel !== levels[`subquery-${subquery.id}`]) {
                levels[`subquery-${subquery.id}`] = newLevel;
                changed = true;
              }
            }
          }
        });
      }
    }
    
    // Find the maximum level for positioning the output
    const maxLevelValue = Object.values(levels).length > 0 
      ? Math.max(...Object.values(levels)) 
      : 0;
    
    // Position calculation
    const baseY = 100;
    const baseX = 50;
    const xSpacing = 300;
    const ySpacing = 120;
    
    // Map to store node counts at each level (for vertical positioning)
    const levelCounts = {};
    const levelPositions = {};
    
    // Initialize level counters
    for (let i = 0; i <= maxLevelValue; i++) {
      levelCounts[i] = 0;
      levelPositions[i] = 0;
    }
    
    // Count nodes per level
    Object.values(levels).forEach(level => {
      levelCounts[level] = (levelCounts[level] || 0) + 1;
    });
    
    // Add source table nodes
    sourceTables.forEach((table) => {
      const level = levels[`table-${table}`] || 0;
      const position = levelPositions[level]++;
      
      newNodes.push({
        id: `table-${table}`,
        type: 'tableNode',
        position: { 
          x: baseX + (level * xSpacing), 
          y: baseY + (position * ySpacing) 
        },
        data: { label: table }
      });
    });
    
    // Add CTE nodes
    ctes.forEach((cte) => {
      const level = levels[`cte-${cte.name}`] || 0;
      const position = levelPositions[level]++;
      const isRecursive = cte.referencedCTEs && cte.referencedCTEs.includes(cte.name);
      
      newNodes.push({
        id: `cte-${cte.name}`,
        type: 'cteNode',
        position: { 
          x: baseX + (level * xSpacing), 
          y: baseY + (position * ySpacing) 
        },
        data: { 
          label: cte.name,
          sourceTables: cte.sourceTables,
          referencedCTEs: cte.referencedCTEs,
          referencedSubqueries: cte.referencedSubqueries,
          query: cte.query,
          isRecursive: isRecursive
        }
      });
      
      // Add edges from source tables to this CTE
      cte.sourceTables.forEach(sourceTable => {
        newEdges.push({
          id: `edge-${sourceTable}-to-${cte.name}`,
          source: `table-${sourceTable}`,
          target: `cte-${cte.name}`,
          markerEnd: {
            type: MarkerType.ArrowClosed,
          },
          style: { stroke: '#6366f1' }
        });
      });
      
      // Add edges from referenced CTEs to this CTE
      if (cte.referencedCTEs) {
        cte.referencedCTEs.forEach(refCte => {
          // Handle self-references with custom edge type
          if (refCte === cte.name) {
            newEdges.push({
              id: `edge-${refCte}-to-self-${cte.name}`,
              source: `cte-${cte.name}`,
              target: `cte-${cte.name}`,
              type: 'selfLoop',  // Use our custom edge type
              markerEnd: {
                type: MarkerType.ArrowClosed,
              },
              style: { stroke: '#8b5cf6' },
              data: { isRecursive: true }
            });
            return;
          }
          
          // Find the actual CTE name with correct case from the ctes array
          const referencedCte = ctes.find(c => c.name.toLowerCase() === refCte.toLowerCase());
          const sourceName = referencedCte ? referencedCte.name : refCte;
          
          newEdges.push({
            id: `edge-${sourceName}-to-${cte.name}`,
            source: `cte-${sourceName}`,
            target: `cte-${cte.name}`,
            markerEnd: {
              type: MarkerType.ArrowClosed,
            },
            style: { stroke: '#8b5cf6' }
          });
        });
      }
      
      // Add edges from referenced subqueries to this CTE
      if (cte.referencedSubqueries && cte.referencedSubqueries.length > 0) {
        cte.referencedSubqueries.forEach(refSq => {
          newEdges.push({
            id: `edge-subquery-${refSq}-to-${cte.name}`,
            source: `subquery-${refSq}`,
            target: `cte-${cte.name}`,
            markerEnd: {
              type: MarkerType.ArrowClosed,
            },
            style: { stroke: '#c2410c' }
          });
        });
      }
    });
    
    // Add subquery nodes
    if (subqueries) {
      subqueries.forEach((subquery) => {
        const level = levels[`subquery-${subquery.id}`] || 0;
        const position = levelPositions[level]++;
        const isRecursive = subquery.referencedSubqueries && 
                            subquery.referencedSubqueries.includes(subquery.id);
        
        newNodes.push({
          id: `subquery-${subquery.id}`,
          type: 'subqueryNode',
          position: { 
            x: baseX + (level * xSpacing), 
            y: baseY + (position * ySpacing) 
          },
          data: { 
            alias: subquery.alias,
            id: subquery.id,
            sourceTables: subquery.sourceTables,
            referencedCTEs: subquery.referencedCTEs,
            referencedSubqueries: subquery.referencedSubqueries,
            query: subquery.query,
            isRecursive: isRecursive
          }
        });
        
        // Add edges from source tables to this subquery
        if (subquery.sourceTables) {
          subquery.sourceTables.forEach(sourceTable => {
            newEdges.push({
              id: `edge-${sourceTable}-to-subquery-${subquery.id}`,
              source: `table-${sourceTable}`,
              target: `subquery-${subquery.id}`,
              markerEnd: {
                type: MarkerType.ArrowClosed,
              },
              style: { stroke: '#c2410c' }
            });
          });
        }
        
        // Add edges from referenced CTEs to this subquery
        if (subquery.referencedCTEs) {
          subquery.referencedCTEs.forEach(refCte => {
            // Find the actual CTE name with correct case from the ctes array
            const referencedCte = ctes.find(c => c.name.toLowerCase() === refCte.toLowerCase());
            const sourceName = referencedCte ? referencedCte.name : refCte;
            
            newEdges.push({
              id: `edge-${sourceName}-to-subquery-${subquery.id}`,
              source: `cte-${sourceName}`,
              target: `subquery-${subquery.id}`,
              markerEnd: {
                type: MarkerType.ArrowClosed,
              },
              style: { stroke: '#c2410c' }
            });
          });
        }
        
        // Add edges from referenced subqueries to this subquery
        if (subquery.referencedSubqueries) {
          subquery.referencedSubqueries.forEach(refSq => {
            // Handle self-references with custom edge type
            if (refSq === subquery.id) {
              newEdges.push({
                id: `edge-subquery-${refSq}-to-self-${subquery.id}`,
                source: `subquery-${subquery.id}`,
                target: `subquery-${subquery.id}`,
                type: 'selfLoop',  // Use our custom edge type
                markerEnd: {
                  type: MarkerType.ArrowClosed,
                },
                style: { stroke: '#c2410c' },
                data: { isRecursive: true }
              });
              return;
            }
            
            newEdges.push({
              id: `edge-subquery-${refSq}-to-subquery-${subquery.id}`,
              source: `subquery-${refSq}`,
              target: `subquery-${subquery.id}`,
              markerEnd: {
                type: MarkerType.ArrowClosed,
              },
              style: { stroke: '#c2410c' }
            });
          });
        }
      });
    }

    // Position for output - center vertically relative to the max level
    const outputY = levelCounts[maxLevelValue] > 0
      ? baseY + ((levelCounts[maxLevelValue] - 1) * ySpacing / 2)
      : baseY;
    
    // Add output node
    newNodes.push({
      id: 'output',
      type: 'outputNode',
      position: { 
        x: baseX + ((maxLevelValue + 1) * xSpacing), 
        y: outputY 
      },
      data: { columns: outputColumns }
    });

    // Add connections to output
    if (compatibleFinalQuerySources && compatibleFinalQuerySources.length > 0) {
      // Use the finalQuerySources field from the API to determine connections
      compatibleFinalQuerySources.forEach(source => {
        // Check if this source is a CTE, subquery or a direct table
        if (source.type === 'cte') {
          newEdges.push({
            id: `edge-${source.name}-to-output`,
            source: `cte-${source.name}`,
            target: 'output',
            markerEnd: {
              type: MarkerType.ArrowClosed,
            },
            style: { stroke: '#10b981' }
          });
        } else if (source.type === 'table') {
          newEdges.push({
            id: `edge-${source.name}-to-output`,
            source: `table-${source.name}`,
            target: 'output',
            markerEnd: {
              type: MarkerType.ArrowClosed,
            },
            style: { stroke: '#10b981' }
          });
        } else if (source.type === 'subquery') {
          newEdges.push({
            id: `edge-subquery-${source.name}-to-output`,
            source: `subquery-${source.name}`,
            target: 'output',
            markerEnd: {
              type: MarkerType.ArrowClosed,
            },
            style: { stroke: '#10b981' }
          });
        }
      });
    } else {
      // Connect the last CTE to the output
      if (ctes.length > 0) {
        const lastCte = ctes[ctes.length - 1];
        newEdges.push({
          id: `edge-${lastCte.name}-to-output`,
          source: `cte-${lastCte.name}`,
          target: 'output',
          markerEnd: {
            type: MarkerType.ArrowClosed,
          },
          style: { stroke: '#10b981' }
        });
      } else if (sourceTables.length > 0) {
        // If no CTEs, connect source tables directly to output
        sourceTables.forEach(table => {
          newEdges.push({
            id: `edge-${table}-to-output`,
            source: `table-${table}`,
            target: 'output',
            markerEnd: {
              type: MarkerType.ArrowClosed,
            },
            style: { stroke: '#10b981' }
          });
        });
      }
    }
    
    setNodes(newNodes);
    setEdges(newEdges);
  }, [sourceTables, ctes, subqueries, outputColumns, data.finalQuerySources]); 
  
  // Build the graph on data change
  useLayoutEffect(() => {
    buildGraph();
  }, [buildGraph]);
  
  return (
    <div style={{ height: '600px', width: '100%' }}>
      <div style={{ marginBottom: '10px' }}>
        {detectedDialect && (
          <div style={{ 
            padding: '8px', 
            background: '#f3f4f6', 
            borderRadius: '4px', 
            marginBottom: '10px',
            display: 'inline-block'
          }}>
            <span style={{ fontWeight: 500 }}>Detected Dialect:</span>
            <span style={{ marginLeft: '8px' }}>{detectedDialect}</span>
          </div>
        )}
      </div>
      
      <div style={{ height: '550px', border: '1px solid #e5e7eb', borderRadius: '8px', overflow: 'hidden' }}>
        <ReactFlow
          nodes={nodes}
          edges={edges}
          onNodesChange={onNodesChange}
          onEdgesChange={onEdgesChange}
          nodeTypes={nodeTypes}
          edgeTypes={edgeTypes}
          fitView
          attributionPosition="bottom-right"
        >
          <Controls />
          <MiniMap />
          <Background color="#aaa" gap={16} />
        </ReactFlow>
      </div>
    </div>
  );
};

// The lineage visualizer component is this module's default export.
export default SQLLineageFlowVisualizer;