import { Circle } from "@mui/icons-material";
import {
  List,
  ListItem,
  ListItemText,
  Slider,
  Stack,
  Typography,
} from "@mui/material";
import type { Maybe } from "~/types";
import { usePlayerActions } from "../../actions";
import type { InitializedPanelNode } from "../../panels";
import type { ImageSegmentationInferenceResults } from "./types";

/**
 * Controls for an image-segmentation inference panel.
 *
 * Renders an opacity slider bound to the panel's inference image opacity and
 * a list of segmentation classes. Each class shows its label, its color
 * swatch, and the scores of the segmentations grouped under it.
 *
 * @param panel - Panel whose `inferenceImageOpacity` is displayed and updated.
 * @param results - Segmentation results to summarize; when absent, the class
 *   list is empty.
 */
export function ImageSegmentationResultsControls({
  panel,
  results,
}: {
  panel: InitializedPanelNode;
  results: Maybe<ImageSegmentationInferenceResults>;
}) {
  const playerActions = usePlayerActions();

  function handleOpacityChange(_: unknown, newOpacity: number | number[]) {
    // MUI types the slider value as `number | number[]` to cover range
    // sliders. This slider has a single thumb, so narrow at runtime instead of
    // asserting with `as number`. The assertion would silently forward an
    // array as the opacity if the slider ever reported one.
    const opacity = Array.isArray(newOpacity) ? newOpacity[0] : newOpacity;
    if (opacity === undefined) {
      return;
    }

    playerActions.setInferenceImageOpacity(panel, opacity);
  }

  const segmentationClasses = groupSegmentations(results);

  return (
    <>
      <Typography>Opacity</Typography>
      <Slider
        sx={{
          alignSelf: "center",
          width: (theme) => `calc(100% - ${theme.spacing(2.5)})`,
        }}
        min={0}
        max={1}
        step={0.01}
        value={panel.inferenceImageOpacity}
        onChange={handleOpacityChange}
      />
      <List dense>
        {segmentationClasses.map((segmentationClass) => (
          <ListItem key={segmentationClass.label}>
            <Stack sx={{ width: 1 }}>
              <Stack
                direction="row"
                sx={{ justifyContent: "space-between", alignItems: "center" }}
              >
                <ListItemText>{segmentationClass.label}</ListItemText>
                <Circle sx={{ color: segmentationClass.color }} />
              </Stack>
              <List dense disablePadding sx={{ pl: 2 }}>
                {segmentationClass.scores.map((score, index) => (
                  <ListItem key={index} disablePadding>
                    <ListItemText>Score: {score}</ListItemText>
                  </ListItem>
                ))}
              </List>
            </Stack>
          </ListItem>
        ))}
      </List>
    </>
  );
}

/** All segmentations sharing one label, rendered as a single list row. */
interface SegmentationClass {
  /** Label shared by every grouped segmentation; may be `null`. */
  label: string | null;
  /** CSS color string used for the row's color swatch. */
  color: string;
  /** Scores of the grouped segmentations; an individual score may be `null`. */
  scores: (number | null)[];
}

/**
 * Buckets segmentations by label, producing one entry per distinct label.
 *
 * Each entry takes its color from the first segmentation seen with that
 * label, formatted as an `rgb(r g b)` CSS string. Each entry collects the
 * scores of all its segmentations. Scores inside an entry are ordered highest
 * first, and entries are ordered by label via `localeCompare`. In both orders,
 * `null` values are placed last.
 *
 * @param results - Inference results to group; `null`/`undefined` yields `[]`.
 * @returns The grouped and sorted segmentation classes.
 */
function groupSegmentations(
  results: Maybe<ImageSegmentationInferenceResults>,
): Array<SegmentationClass> {
  // Comparator wrapper: `null` sorts after any non-null value, two nulls tie,
  // and two non-null values defer to `compare`.
  function nullsLast<T>(
    left: T | null,
    right: T | null,
    compare: (x: T, y: T) => number,
  ): number {
    if (left === null) {
      return right === null ? 0 : 1;
    }
    if (right === null) {
      return -1;
    }
    return compare(left, right);
  }

  const byLabel = new Map<string | null, SegmentationClass>();

  for (const segmentation of results?.segmentations ?? []) {
    const existing = byLabel.get(segmentation.label);
    if (existing !== undefined) {
      existing.scores.push(segmentation.score);
      continue;
    }

    const [r, g, b] = segmentation.color;
    byLabel.set(segmentation.label, {
      label: segmentation.label,
      color: `rgb(${r} ${g} ${b})`,
      scores: [segmentation.score],
    });
  }

  const grouped = Array.from(byLabel.values());

  for (const entry of grouped) {
    // Highest score first.
    entry.scores.sort((x, y) => nullsLast(x, y, (p, q) => q - p));
  }

  grouped.sort((x, y) =>
    nullsLast(x.label, y.label, (p, q) => p.localeCompare(q)),
  );

  return grouped;
}
