import * as z from "zod";
import { chunk } from "~/lib/std";

/**
 * Identifier of a detection's class. COCO-style inputs supply numeric
 * category IDs, while the Hugging Face format reuses the label string.
 */
export type ClassId = string | number;

/**
 * A single polygon outline, expressed as a list of (x, y) vertices.
 */
interface Polygon {
  points: readonly { x: number; y: number }[];
}

/**
 * One detected object, normalized from any of the supported result formats.
 */
export interface Detection {
  /** Class identifier exactly as the source format provides it. */
  classId: ClassId;
  /** Human-readable class label. */
  className: string;
  /** Axis-aligned box described by its top-left corner and its size. */
  boundingBox: {
    topLeftX: number;
    topLeftY: number;
    width: number;
    height: number;
  };
  /** Segmentation outlines; empty when the source format carries none. */
  polygons: readonly Polygon[];
}

/**
 * Normalized detection output: the dimensions of the source image together
 * with the detections parsed for it.
 */
export interface DetectionInferenceResults {
  imageWidth: number;
  imageHeight: number;
  detections: readonly Detection[];
}

/** COCO bounding box `[x, y, width, height]`, where (x, y) is the top-left corner. */
const cocoBboxSchema = z.tuple([z.number(), z.number(), z.number(), z.number()]);

/**
 * One COCO polygon segmentation, given as a flat coordinate list
 * `[x1, y1, x2, y2, ...]`. Every vertex contributes two numbers, so the
 * list length must be even. The rule counts coordinates, not points: a
 * triangle has 3 points but 6 numbers.
 */
const cocoSegmentationSchema = z
  .array(z.number())
  .refine(
    (value) => value.length % 2 === 0,
    "Expected an even number of coordinates (x/y pairs)",
  );

/**
 * Converts a flat COCO segmentation `[x1, y1, x2, y2, ...]` into a polygon
 * whose points are consecutive (x, y) pairs.
 *
 * @param segmentation - Flat list of coordinates; its length must be even.
 * @returns The polygon built from consecutive coordinate pairs.
 * @throws RangeError if `segmentation` has an odd length. The schemas in this
 *   module only pass arrays already validated by `cocoSegmentationSchema`, so
 *   this is an invariant guard. Without it, an unpaired trailing coordinate
 *   could not form a valid (x, y) point.
 */
function transformSegmentationToPolygon(
  segmentation: ReadonlyArray<number>,
): Polygon {
  if (segmentation.length % 2 !== 0) {
    throw new RangeError(
      `Segmentation must contain an even number of coordinates, got ${segmentation.length}`,
    );
  }

  return {
    points: chunk(segmentation, 2).map(([x, y]) => ({ x, y })),
  };
}

/**
 * Standard COCO JSON (`images` / `categories` / `annotations`), normalized to
 * {@link DetectionInferenceResults}.
 *
 * Only `images[0]` is used. Its width and height become the image dimensions,
 * and every annotation is attributed to it, because `image_id` is not part of
 * this schema. Each annotation's `category_id` must match an `id` in
 * `categories`. If any do not, parsing fails with one fatal issue per
 * offending annotation, located at `annotations.<index>.category_id`.
 */
const standardCocoSchema = z
  .object({
    images: z
      .array(
        z.object({
          width: z.number(),
          height: z.number(),
        }),
      )
      .nonempty(),
    categories: z.array(
      z.object({
        id: z.number(),
        name: z.string(),
      }),
    ),
    annotations: z.array(
      z.object({
        bbox: cocoBboxSchema,
        category_id: z.number(),
        segmentation: z.array(cocoSegmentationSchema),
      }),
    ),
  })
  .transform((value, ctx): DetectionInferenceResults => {
    const {
      annotations,
      categories,
      // Assumes a single image: any further entries in `images` are ignored.
      images: [{ width: imageWidth, height: imageHeight }],
    } = value;

    const categoryMap = new Map(categories.map(({ id, name }) => [id, name]));

    const detections = new Array<Detection>();
    let unresolvedCount = 0;

    annotations.forEach((annotation, index) => {
      const {
        // COCO boxes are [x, y, width, height] with (x, y) the top-left corner.
        bbox: [topLeftX, topLeftY, bboxWidth, bboxHeight],
        segmentation,
        category_id: classId,
      } = annotation;

      const className = categoryMap.get(classId);
      if (className === undefined) {
        // Record the problem and keep scanning, so that every unresolved
        // annotation is reported rather than only the first one. The path
        // points the caller at the offending annotation.
        ctx.addIssue({
          code: z.ZodIssueCode.custom,
          message: `No category found with ID ${classId}`,
          path: ["annotations", index, "category_id"],
          fatal: true,
        });
        unresolvedCount += 1;
        return;
      }

      detections.push({
        classId,
        className,
        boundingBox: {
          topLeftX,
          topLeftY,
          width: bboxWidth,
          height: bboxHeight,
        },
        polygons: segmentation.map(transformSegmentationToPolygon),
      });
    });

    if (unresolvedCount > 0) {
      return z.NEVER;
    }

    return {
      imageWidth,
      imageHeight,
      detections,
    };
  });

/**
 * COCO variant in which image dimensions live under `img_attributes`. Each
 * annotation embeds its category (`id` + `name`) directly and wraps every
 * polygon in a `{ segmentation }` object.
 */
const crlCocoSchema = z
  .object({
    img_attributes: z.object({
      width: z.number(),
      height: z.number(),
    }),
    annotations: z.array(
      z.object({
        bbox: cocoBboxSchema,
        category: z.object({
          id: z.number(),
          name: z.string(),
        }),
        segmentations: z.array(
          z.object({
            segmentation: cocoSegmentationSchema,
          }),
        ),
      }),
    ),
  })
  .transform((value): DetectionInferenceResults => {
    const imageSize = value.img_attributes;

    // Categories are inline, so no lookup table is needed. Each annotation
    // maps 1:1 onto a detection.
    const detections = value.annotations.map(
      ({
        bbox: [topLeftX, topLeftY, width, height],
        category,
        segmentations,
      }): Detection => ({
        classId: category.id,
        className: category.name,
        boundingBox: { topLeftX, topLeftY, width, height },
        polygons: segmentations.map((entry) =>
          transformSegmentationToPolygon(entry.segmentation),
        ),
      }),
    );

    return {
      imageWidth: imageSize.width,
      imageHeight: imageSize.height,
      detections,
    };
  });

/**
 * Hugging Face object-detection output. `meta.image` carries the image size,
 * and each `output` entry has a `label` plus a corner-based `box`
 * (`xmin`/`ymin`/`xmax`/`ymax`). The label doubles as the class ID, and no
 * polygons are produced.
 */
const huggingFaceSchema = z
  .object({
    meta: z.object({
      image: z.object({
        width: z.number(),
        height: z.number(),
      }),
    }),
    output: z.array(
      z.object({
        label: z.string(),
        box: z.object({
          xmin: z.number(),
          ymin: z.number(),
          xmax: z.number(),
          ymax: z.number(),
        }),
      }),
    ),
  })
  .transform((value): DetectionInferenceResults => {
    const { width: imageWidth, height: imageHeight } = value.meta.image;

    // Convert opposite-corner boxes into top-left + size boxes.
    const detections = value.output.map(
      ({ label, box: { xmin, ymin, xmax, ymax } }): Detection => ({
        classId: label,
        className: label,
        boundingBox: {
          topLeftX: xmin,
          topLeftY: ymin,
          width: xmax - xmin,
          height: ymax - ymin,
        },
        polygons: [],
      }),
    );

    return { imageWidth, imageHeight, detections };
  });

/**
 * Parses detection results in any supported format and normalizes them to
 * {@link DetectionInferenceResults}.
 *
 * The formats are tried in this order: standard COCO, the CRL COCO variant,
 * then Hugging Face object-detection output. `z.union` returns the result of
 * the first option that parses successfully, so keep this order in mind if an
 * input could satisfy more than one schema.
 */
export const detectionResultsSchema = z.union([
  standardCocoSchema,
  crlCocoSchema,
  huggingFaceSchema,
]);
