import React, { useState } from "react";
import type { InitializedPanelNode } from "../../panels";
import { calculateRotationQuadrant, FlipDirection } from "../../panels";
import { getTransformProps } from "../../visualizations/image-visualization/utils";
import type { InferenceDataForType } from "../types";
import type { ClassId, Detection } from "./schemas";

/**
 * SVG overlay that draws object-detection results over the inference image.
 *
 * The `<svg>` viewBox matches the inference image size, and the element receives
 * the panel's inference rotation/flip through `getTransformProps`. It draws the
 * following for every detection whose class name is not in `hiddenObjectClassNames`:
 * - the bounding box, stroked only when `showDetectionBoundingBoxes` is set;
 * - each polygon, with a translucent fill;
 * - the class name at the box's top-left corner, when `showDetectionClassNames` is set.
 *
 * Hovering a detection isolates its class:
 * - detections of other classes are not rendered;
 * - per-box labels of the hovered class are suppressed;
 * - the hovered class name is drawn once, larger, at the image's top-left instead.
 *
 * @param panel - Panel node supplying the inference rotation/flip and display toggles.
 * @param results - Detection results plus the image dimensions they refer to.
 */
export function DetectionResultsVisualization({
  panel,
  results,
}: {
  panel: InitializedPanelNode;
  results: InferenceDataForType<
    "object-detection" | "zero-shot-object-detection"
  >;
}) {
  const {
    inferenceRotationDeg: rotationDeg,
    inferenceFlipDirection: flipDirection,
    hiddenObjectClassNames,
    showDetectionBoundingBoxes,
    showDetectionClassNames,
  } = panel;

  const [hoveredClassId, setHoveredClassId] = useState<ClassId | null>(null);

  function createPointerEnterHandler(classId: ClassId) {
    return function handlePointerEnter() {
      setHoveredClassId(classId);
    };
  }

  function handlePointerLeave() {
    setHoveredClassId(null);
  }

  // Only honour the hover while the hovered class still has a visible detection.
  // `hoveredClassId` can go stale: if `results` or `hiddenObjectClassNames`
  // change while the pointer is over a box, that box may unmount without a
  // pointer-leave reaching `handlePointerLeave`. Filtering by a stale id would
  // drop every detection and leave the overlay empty.
  const hoveredClass = results.detections.find(
    (detection) =>
      detection.classId === hoveredClassId &&
      !hiddenObjectClassNames.includes(detection.className),
  );
  const activeHoveredClassId = hoveredClass?.classId ?? null;

  return (
    <svg
      viewBox={`0 0 ${results.imageWidth} ${results.imageHeight}`}
      {...getTransformProps({
        rotationDeg,
        flipDirection,
      })}
    >
      {hoveredClass !== undefined && (
        <g>
          <text
            style={{
              transform: calculateTextTransform({
                bboxWidth: results.imageWidth,
                bboxHeight: results.imageHeight,
                rotationDeg,
                flipDirection,
              }),
            }}
            color={computeDetectionClassColor(hoveredClass)}
            fill="currentcolor"
            stroke="currentcolor"
            x={0}
            y={0}
            dx={3}
            dy={24}
            fontSize={24}
            fontFamily="monospace"
          >
            {hoveredClass.className}
          </text>
        </g>
      )}
      {results.detections.flatMap((detection, detectionIndex) => {
        const isOtherClassHovered =
          activeHoveredClassId !== null &&
          detection.classId !== activeHoveredClassId;

        if (
          isOtherClassHovered ||
          hiddenObjectClassNames.includes(detection.className)
        ) {
          return [];
        }

        return (
          <g
            key={detectionIndex}
            color={computeDetectionClassColor(detection)}
            fill="currentcolor"
            stroke="currentcolor"
            onPointerEnter={createPointerEnterHandler(detection.classId)}
            onPointerLeave={handlePointerLeave}
          >
            <rect
              x={detection.boundingBox.topLeftX}
              y={detection.boundingBox.topLeftY}
              width={detection.boundingBox.width}
              height={detection.boundingBox.height}
              strokeWidth={3}
              stroke={showDetectionBoundingBoxes ? undefined : "none"}
              fill="none"
            />
            {detection.polygons.map((polygon, polygonIndex) => (
              <polygon
                key={polygonIndex}
                fillOpacity={0.25}
                points={polygon.points.map(({ x, y }) => `${x},${y}`).join(" ")}
              />
            ))}
            {/* The hovered class's name is already shown in the top-left label. */}
            {showDetectionClassNames &&
              detection.classId !== activeHoveredClassId && (
                <text
                  style={{
                    transformOrigin: [
                      `${detection.boundingBox.topLeftX}px`,
                      `${detection.boundingBox.topLeftY}px`,
                    ].join(" "),
                    transform: calculateTextTransform({
                      bboxWidth: detection.boundingBox.width,
                      bboxHeight: detection.boundingBox.height,
                      rotationDeg,
                      flipDirection,
                    }),
                  }}
                  x={detection.boundingBox.topLeftX}
                  y={detection.boundingBox.topLeftY}
                  dx={3}
                  dy={15}
                  fontSize={14}
                  fontFamily="monospace"
                >
                  {detection.className}
                </text>
              )}
          </g>
        );
      })}
    </svg>
  );
}

/**
 * Builds the CSS `transform` value applied to a text label inside the
 * transformed `<svg>`.
 *
 * The label is handled in three steps:
 * - Mirror: mirrored on the same axis as `flipDirection` (X or Y).
 * - Rotate: rotated by `rotationDeg`. The angle is negated when there is no
 *   X/Y flip and kept positive when there is one, presumably to cancel the
 *   rotation `getTransformProps` applies to the parent, which a mirror reverses.
 *   NOTE(review): confirm against `getTransformProps`.
 * - Translate: shifted by the full `bboxWidth` and/or `bboxHeight`, depending
 *   on the rotation quadrant and flip direction.
 *
 * @returns A string of the form `translate(Xpx, Ypx) rotate(Ddeg) scale(SX, SY)`.
 */
function calculateTextTransform({
  bboxWidth,
  bboxHeight,
  rotationDeg,
  flipDirection,
}: {
  bboxWidth: number;
  bboxHeight: number;
  rotationDeg: number;
  flipDirection: FlipDirection | null;
}): string {
  const isFlippedX = flipDirection === FlipDirection.X;
  const isFlippedY = flipDirection === FlipDirection.Y;

  const scaleX = isFlippedX ? -1 : 1;
  const scaleY = isFlippedY ? -1 : 1;
  const rotationSign = isFlippedX || isFlippedY ? 1 : -1;

  const rotationQuadrant = calculateRotationQuadrant(rotationDeg);
  // Slot 3 is the fallback for any quadrant value other than 0, 1 or 2.
  const slot =
    rotationQuadrant === 0
      ? 0
      : rotationQuadrant === 1
        ? 1
        : rotationQuadrant === 2
          ? 2
          : 3;

  // Per slot: [shift by bboxWidth?, shift by bboxHeight?]. The last table
  // covers every flip direction other than null and X.
  type Shift = readonly [boolean, boolean];
  const shiftsBySlot: readonly [Shift, Shift, Shift, Shift] =
    flipDirection === null
      ? [
          [false, false],
          [false, true],
          [true, true],
          [true, false],
        ]
      : isFlippedX
        ? [
            [true, false],
            [true, true],
            [false, true],
            [false, false],
          ]
        : [
            [false, true],
            [false, false],
            [true, false],
            [true, true],
          ];

  const [shiftByWidth, shiftByHeight] = shiftsBySlot[slot];
  const translateX = shiftByWidth ? bboxWidth : 0;
  const translateY = shiftByHeight ? bboxHeight : 0;

  return (
    `translate(${translateX}px, ${translateY}px) ` +
    `rotate(${rotationSign * rotationDeg}deg) ` +
    `scale(${scaleX}, ${scaleY})`
  );
}

// The hue circle is split into HUE_SECTIONS equal sections of
// HUE_SECTION_SIZE_DEG degrees each.
const MAX_HUE_DEG = 360;
const HUE_SECTIONS = 6;
const HUE_SECTION_SIZE_DEG = MAX_HUE_DEG / HUE_SECTIONS;
// Fixed oklch lightness and chroma, as percentages, shared by every class color.
const LIGHTNESS_PCT = 65;
const CHROMA_PCT = 60;

/**
 * Derives a deterministic CSS `oklch()` color for a detection's class.
 *
 * The hue is built from two parts:
 * - Section: `classId % HUE_SECTIONS` when `classId` is a number; otherwise
 *   the sum of the Unicode code points of `className`, modulo `HUE_SECTIONS`.
 * - Offset: the code-point sum modulo `HUE_SECTION_SIZE_DEG`, added within
 *   the chosen section.
 *
 * NOTE(review): when `classId` is not a number, both parts come from the same
 * sum. Because `HUE_SECTIONS` divides `HUE_SECTION_SIZE_DEG`, the hue then
 * depends only on `sum % 60`, so at most 60 distinct hues are produced.
 *
 * @param detection - Object providing `classId` and `className`.
 * @returns A string such as `oklch(65% 60% 72)`.
 */
export function computeDetectionClassColor(
  detection: Readonly<Pick<Detection, "classId" | "className">>,
): string {
  const { classId, className } = detection;

  // String iteration walks Unicode code points, not UTF-16 code units.
  let codePointsSum = 0;
  for (const character of className) {
    codePointsSum += character.codePointAt(0) ?? 0;
  }

  const sectionSeed = typeof classId === "number" ? classId : codePointsSum;
  const hueSection = sectionSeed % HUE_SECTIONS;
  const hueOffsetDeg = codePointsSum % HUE_SECTION_SIZE_DEG;
  const hueDeg = hueSection * HUE_SECTION_SIZE_DEG + hueOffsetDeg;

  return `oklch(${LIGHTNESS_PCT}% ${CHROMA_PCT}% ${hueDeg})`;
}
