import React, { useEffect } from "react";
import { LoadingButton } from "@mui/lab";
import { Button, Stack, Typography } from "@mui/material";
import { useMutation } from "@tanstack/react-query";
import { useSnackbar } from "notistack";
import { useWatch } from "react-hook-form";
import { z } from "zod";
import { JsonField } from "~/components/DetailsCards/JsonField";
import type {
  SelectFieldProps,
  StudioFormValues,
  TextFieldProps,
  UseStudioFormReturn,
} from "~/components/Form";
import { SelectField, TextField, useStudioForm } from "~/components/Form";
import type { Option } from "~/components/Form/types";
import { createSafeContext } from "~/contexts";
import { optionalText, requiredEnum } from "~/domain/common";
import { invariant } from "~/lib/invariant";
import { startCase } from "~/lib/std";
import { useCurrentDataStore } from "~/lqs";
import type { Maybe } from "~/types";
import { getEventHandlerProps } from "~/utils";
import type { InitializedPanelNode, PanelNode } from "../../panels";
import { useInitializedPanel } from "../../panels";
import { usePlaybackSource } from "../../playback";
import { DetectionResultsControls } from "../detections";
import { ImageResultsControls } from "../images";
import { ImageSegmentationResultsControls } from "../segmentations";
import type { InferencePipelineTask, OnDemandInferenceResult } from "../types";
import { runOnDemandInference } from "./api";
import type { PipelineModelFieldProps } from "./pipeline-model-field";
import { PipelineModelField } from "./pipeline-model-field";
import { recordRecentModel } from "./recent-models";
import {
  isNullSnapshot,
  useInferenceStoreSnapshot,
} from "./use-inference-store-snapshot";
import type { OnDemandInferenceOperation } from "./use-operation-controller";
import {
  isPanelImageOperation,
  isPanelImageOperationSource,
  useOperationController,
} from "./use-operation-controller";

/**
 * Inference pipeline tasks offered by the on-demand inference form, in the
 * order they are presented to the user.
 */
const TASK_ENUM_VALUES = [
  "object-detection",
  "image-segmentation",
  "depth-estimation",
] as const;

/** Select-field options for the task field: start-cased label, raw value. */
const taskOptions: ReadonlyArray<Option> = Array.from(
  TASK_ENUM_VALUES,
  (task) => ({ label: startCase(task), value: task }),
);

/**
 * Model identifiers suggested in the model field for each task. `satisfies`
 * forces an entry for every `InferencePipelineTask` while keeping the literal
 * types.
 */
const suggestedTaskModels = {
  "object-detection": ["LogQS-ML/detr-resnet-50"],
  "image-segmentation": ["facebook/detr-resnet-50-panoptic"],
  "depth-estimation": ["LiheYoung/depth-anything-small-hf"],
} satisfies Record<InferencePipelineTask, ReadonlyArray<string>>;

/**
 * Schema for the task field, built from `TASK_ENUM_VALUES`. The `satisfies`
 * check guarantees its output is assignable to `InferencePipelineTask`.
 */
const taskSchema = requiredEnum(
  TASK_ENUM_VALUES,
) satisfies z.ZodType<InferencePipelineTask>;

/**
 * Validation schema for the on-demand inference form.
 *
 * `model` is declared as optional text in the object schema and only enforced
 * in the transform. Once the transform runs, `task` has already passed
 * validation. Marking `model` required up front would put an error on a field
 * that is still disabled (no task chosen) and that the user never had a chance
 * to fill out.
 */
const schema = z
  .object({
    task: taskSchema,
    model: optionalText,
    revision: optionalText,
  })
  .transform((values, ctx) => {
    // Pull the fields into locals so TS can narrow `model` below.
    const { task, model, revision } = values;

    if (model !== null) {
      return { task, model, revision };
    }

    ctx.addIssue({
      code: z.ZodIssueCode.custom,
      path: ["model"],
      message: "Field is required",
    });

    return z.NEVER;
  });

/** Values produced by a successful parse of `schema` (i.e. after its transform). */
export type _OnDemandInferenceFormValues = z.infer<typeof schema>;

type FormReturnValue = UseStudioFormReturn<_OnDemandInferenceFormValues>;

/** Field-value shape the studio form tracks for this schema. */
type OnDemandInferenceFieldValues =
  StudioFormValues<_OnDemandInferenceFormValues>;

/**
 * State and callbacks shared by every `OnDemandInference` status variant:
 * the current operation, panel/layout event handlers, the form's control and
 * submit handler, and prop factories for each form field.
 */
interface BaseOnDemandInference {
  operation: OnDemandInferenceOperation | null;
  handleClearPanel: (panel: PanelNode) => void;
  handleClosePanel: (panel: PanelNode) => void;
  handleChangePanelVisualization: (panel: PanelNode) => void;
  handleLoadLayout: () => void;
  /** When `true`, submitting the form is not allowed. */
  disabled: boolean;
  formControl: FormReturnValue["control"];
  submitHandler: FormReturnValue["handleSubmit"];
  getTaskFieldProps: () => SelectFieldProps<OnDemandInferenceFieldValues, "task">;
  /** Props for the model field, given the currently selected task. */
  getModelFieldProps: (
    task: InferencePipelineTask | null,
  ) => PipelineModelFieldProps;
  /** Props for the revision field, given the currently selected model. */
  getRevisionFieldProps: (
    model: string | null,
  ) => TextFieldProps<OnDemandInferenceFieldValues, "revision">;
}

/** No relevant result and no request in flight or failed. */
interface IdleOnDemandInference extends BaseOnDemandInference {
  status: "idle";
}

/** An inference request is in flight. */
interface PendingOnDemandInference extends BaseOnDemandInference {
  status: "pending";
}

/** The most recent inference request failed. */
interface RejectedOnDemandInference extends BaseOnDemandInference {
  status: "rejected";
}

/** A result is available; `clearResult` discards it. */
interface FulfilledOnDemandInference extends BaseOnDemandInference {
  status: "fulfilled";
  value: OnDemandInferenceResult;
  clearResult: () => void;
}

/** Shared on-demand inference state, discriminated by `status`. */
export type OnDemandInference =
  | IdleOnDemandInference
  | PendingOnDemandInference
  | RejectedOnDemandInference
  | FulfilledOnDemandInference;

/**
 * Builds the shared on-demand inference state exposed through
 * `OnDemandInferenceContext`.
 *
 * This hook combines several inputs into one `OnDemandInference` value
 * discriminated by `status`:
 *   - the playback source,
 *   - the operation controller's panel/image source,
 *   - the inference store snapshot for that source,
 *   - the inference mutation,
 *   - the task/model/revision form.
 *
 * A successful result is only reported as "fulfilled" while it is still
 * relevant: playback is paused and the current frame has the same timestamp
 * that was used to produce the result. Otherwise the result is cleared.
 */
function useOnDemandInferenceProvider(): OnDemandInference {
  const { id: dataStoreId } = useCurrentDataStore();

  const playbackSource = usePlaybackSource();

  const operationController = useOperationController();

  const { snapshot } = useInferenceStoreSnapshot(operationController.operation);

  const imageTimestamp: Maybe<bigint> = isNullSnapshot(snapshot)
    ? null
    : snapshot.value?.current?.timestamp;

  const isPlaybackPaused = !playbackSource.isLoading && !playbackSource.isPlaying;

  const enabled =
    // Checking that the snapshot is not a null snapshot lets TS narrow
    // `snapshot` inside the mutation function (after `invariant(enabled)`),
    // so it knows the request has an image topic attached to it.
    !isNullSnapshot(snapshot) && isPlaybackPaused && imageTimestamp != null;

  const mutation = useMutation({
    async mutationFn({
      task,
      model,
      revision,
    }: _OnDemandInferenceFormValues): Promise<OnDemandInferenceResult> {
      invariant(enabled, "Mutation not enabled");

      const result = await runOnDemandInference({
        dataStoreId,
        topicId: snapshot.request.imageTopic.id,
        timestamp: imageTimestamp,
        pipelineTask: task,
        pipelineModel: model,
        pipelineRevision: revision,
      });

      return {
        task,
        timestamp: imageTimestamp,
        ...result,
      } as OnDemandInferenceResult;
    },
  });

  const { enqueueSnackbar } = useSnackbar();

  const form = useStudioForm({
    schema,
    defaultValues: {
      task: null,
      model: null,
      revision: null,
    },
    onSubmit(values) {
      invariant(enabled, "Form not enabled");

      recordRecentModel(values.task, values.model);

      // NOTE(review): this lock is not released in `onError` below, so a
      // failed request leaves the source locked — confirm that's intended.
      operationController.lockPanelImageOperationSource();

      mutation.mutate(values, {
        onError(error) {
          if (error instanceof z.ZodError) {
            enqueueSnackbar("Unrecognized inference result type", {
              variant: "error",
            });
          } else {
            enqueueSnackbar("Unable to run inference task", {
              variant: "error",
            });
          }
        },
      });
    },
  });

  // Once an on-demand result is visible, it can become "irrelevant" in some
  // cases:
  //   1. Playback is no longer paused
  //   2. The current image frame's timestamp isn't the same as the one used
  //      to create the on-demand result
  // This single predicate drives both the reset effect and the returned
  // status, so the two can't drift apart.
  const isResultRelevant =
    mutation.isSuccess &&
    isPlaybackPaused &&
    mutation.data.timestamp === imageTimestamp;

  /** Releases the panel/image source and discards any mutation state. */
  function handleClearResult() {
    operationController.clearPanelImageOperationSource();
    mutation.reset();
  }

  // Intentionally no dependency array: re-check relevance after every render.
  useEffect(function resetWhenNoLongerRelevant() {
    if (mutation.isSuccess && !isResultRelevant) {
      handleClearResult();
    }
  });

  /** Clears the current result if `panel` is the operation's source panel. */
  function clearSourceIfNeeded(panel: PanelNode) {
    if (
      panel.isInitialized &&
      isPanelImageOperationSource(operationController.operation, panel)
    ) {
      handleClearResult();
    }
  }

  const baseFields: BaseOnDemandInference = {
    operation: operationController.operation,
    handleClearPanel: clearSourceIfNeeded,
    handleClosePanel: clearSourceIfNeeded,
    handleChangePanelVisualization: clearSourceIfNeeded,
    handleLoadLayout() {
      if (isPanelImageOperation(operationController.operation)) {
        handleClearResult();
      }
    },
    disabled: !enabled || mutation.isLoading,
    formControl: form.control,
    submitHandler: form.handleSubmit,
    getTaskFieldProps() {
      return {
        control: form.control,
        name: "task",
        required: true,
        options: taskOptions,
        disabled: mutation.isLoading,
        onChange(): void {
          // Model and revision depend on the task; reset them when it changes.
          form.resetField("model");
          form.resetField("revision");
        },
      };
    },
    getModelFieldProps(task) {
      const taskSelected = task !== null;
      const suggestedModels = taskSelected ? suggestedTaskModels[task] : [];

      return {
        // Ensure internal autocomplete state is reset when the task changes
        key: task,
        control: form.control,
        name: "model",
        task,
        required: true,
        disabled: !taskSelected || mutation.isLoading,
        suggestedModels,
        onChange(): void {
          form.resetField("revision");
        },
      };
    },
    getRevisionFieldProps(model) {
      const modelSelected = model !== null;

      return {
        control: form.control,
        name: "revision",
        disabled: !modelSelected || mutation.isLoading,
      };
    },
  };

  if (isResultRelevant) {
    return {
      status: "fulfilled",
      ...baseFields,
      value: mutation.data,
      clearResult: handleClearResult,
    };
  } else {
    return {
      status: mutation.isError
        ? "rejected"
        : mutation.isLoading
          ? "pending"
          : "idle",
      ...baseFields,
    };
  }
}

/**
 * Consumer hook and context for the shared on-demand inference state, created
 * with `createSafeContext` under the name "OnDemandInference".
 */
export const [useOnDemandInference, OnDemandInferenceContext] =
  createSafeContext<OnDemandInference>("OnDemandInference");

/**
 * Maintains on-demand inference state to be shared between visualizations and
 * controls across the Player. Descendants read it through
 * `useOnDemandInference`.
 */
export function OnDemandInferenceProvider(props: {
  children: React.ReactNode;
}) {
  const sharedState = useOnDemandInferenceProvider();

  return (
    <OnDemandInferenceContext.Provider value={sharedState}>
      {props.children}
    </OnDemandInferenceContext.Provider>
  );
}

/**
 * View of the shared state for a panel that is not the operation's source.
 * `sourceStatus` carries the shared state's own status, so the panel can
 * explain what is happening elsewhere.
 */
interface DisabledOnDemandInference extends BaseOnDemandInference {
  status: "disabled";
  sourceStatus: OnDemandInference["status"];
  disabled: true;
}

/**
 * Returns the shared on-demand inference state as seen by `panel`.
 *
 * If `panel` is the operation's image source, the shared state is returned
 * unchanged. Otherwise a disabled view is returned: status is "disabled",
 * `disabled` is `true`, and the original status is kept in `sourceStatus`.
 */
export function usePanelOnDemandInference(
  panel: InitializedPanelNode,
): OnDemandInference | DisabledOnDemandInference {
  const sharedState = useOnDemandInference();

  if (!isPanelImageOperationSource(sharedState.operation, panel)) {
    return {
      ...sharedState,
      status: "disabled",
      sourceStatus: sharedState.status,
      disabled: true,
    };
  }

  return sharedState;
}

/**
 * On-demand inference controls for the current (initialized) panel.
 *
 * Renders:
 *   - the task/model/revision form bound to the shared form state,
 *   - a "Run Task" submit button,
 *   - messages when a different panel is running or showing inference.
 *
 * When this panel has a fulfilled result, it also renders the result's
 * metadata, a "Clear Result" button and the task-specific result controls.
 */
export function OnDemandInferenceControls() {
  const panel = useInitializedPanel();

  const onDemandInference = usePanelOnDemandInference(panel);

  const {
    formControl,
    submitHandler,
    getTaskFieldProps,
    getModelFieldProps,
    getRevisionFieldProps,
  } = onDemandInference;

  const watchedTask = useWatch({
    control: formControl,
    name: "task",
  });
  const watchedModel = useWatch({
    control: formControl,
    name: "model",
  });

  // The model field's props carry a `key` (the selected task) so the field
  // remounts and its autocomplete state resets when the task changes. Pass
  // `key` to JSX directly: React 18.3+ warns when a spread props object
  // contains `key`.
  const { key: modelFieldKey, ...modelFieldProps } =
    getModelFieldProps(watchedTask);

  const clearResultHandlerProps = getEventHandlerProps(
    "onClick",
    onDemandInference.status === "fulfilled" && onDemandInference.clearResult,
  );

  let taskSpecificControls = null;
  if (onDemandInference.status === "fulfilled") {
    switch (onDemandInference.value.task) {
      case "object-detection": {
        taskSpecificControls = (
          <DetectionResultsControls
            panel={panel}
            results={onDemandInference.value.data.output}
          />
        );
        break;
      }
      case "image-segmentation": {
        taskSpecificControls = (
          <ImageSegmentationResultsControls
            panel={panel}
            results={onDemandInference.value.data.output}
          />
        );
        break;
      }
      case "depth-estimation": {
        taskSpecificControls = <ImageResultsControls panel={panel} />;
        break;
      }
    }
  }

  return (
    <>
      <Typography variant="h6" component="p" sx={{ mb: 2 }}>
        On-Demand Inference
      </Typography>
      <Stack spacing={2} component="form" onSubmit={submitHandler} noValidate>
        <SelectField {...getTaskFieldProps()} />
        <PipelineModelField key={modelFieldKey} {...modelFieldProps} />
        <TextField {...getRevisionFieldProps(watchedModel)} />
        <LoadingButton
          type="submit"
          color="primary"
          variant="contained"
          fullWidth
          disableElevation
          disabled={onDemandInference.disabled}
          loading={onDemandInference.status === "pending"}
        >
          Run Task
        </LoadingButton>
        {onDemandInference.status === "disabled" &&
          onDemandInference.sourceStatus === "pending" && (
            <Typography>Running inference against a different panel</Typography>
          )}
        {onDemandInference.status === "disabled" &&
          onDemandInference.sourceStatus === "fulfilled" && (
            <Typography>
              Showing inference results for a different panel
            </Typography>
          )}
        {onDemandInference.status === "fulfilled" && (
          <>
            <JsonField value={onDemandInference.value.data.meta} />
            <Button
              type="button"
              color="secondary"
              variant="contained"
              fullWidth
              disableElevation
              {...clearResultHandlerProps}
            >
              Clear Result
            </Button>
          </>
        )}
      </Stack>
      {taskSpecificControls}
    </>
  );
}
