import React, { useId, useMemo, useState } from "react";
import { InsertChart } from "@mui/icons-material";
import {
  Backdrop,
  Box,
  Popper,
  styled,
  ToggleButton,
  Tooltip,
  Typography,
} from "@mui/material";
import type { AnimatedLineProps, ChartContainerProps } from "@mui/x-charts";
import {
  ChartContainer,
  ChartsClipPath,
  ChartsGrid,
  ChartsLegend,
  ChartsReferenceLine,
  ChartsTooltipCell,
  chartsTooltipClasses,
  ChartsTooltipMark,
  ChartsTooltipPaper,
  ChartsTooltipRow,
  ChartsTooltipTable,
  ChartsXAxis,
  ChartsYAxis,
  LineHighlightElement,
  LinePlot,
  MarkPlot,
  useDrawingArea,
  useXScale,
  useYScale,
} from "@mui/x-charts";
import type { StrictOmit } from "ts-essentials";
import useResizeObserver from "use-resize-observer";
import { Loading } from "~/components/Loading";
import { ErrorMessage } from "~/components/error-message";
import { createSafeContext } from "~/contexts";
import { secondsToNanoseconds, utcToRelativeNanoseconds } from "~/lib/dates";
import { get } from "~/lib/std";
import type { Record, Topic } from "~/lqs";
import type { TimeRange } from "~/types";
import { LoadingFeedback, PanelLayout } from "../../components";
import type { PointerLocation } from "../../hooks";
import {
  usePointerLocation,
  useSkipToFirstTimestamp,
  useUpdatePanelBuffering,
} from "../../hooks";
import type { ChartPanel } from "../../panels";
import { getPrimaryTopicDescriptor, VisualizationType } from "../../panels";
import {
  useFormatPlaybackTimestamp,
  useLoadedPlaybackSource,
} from "../../playback";
import type { PlayerRecord } from "../../record-store";
import { useAllRecords, useRecords } from "../../record-store";
import { calculateRecordWindow } from "../../utils";
import { useVisualizationStoreParams } from "../context";
import { calculateWindowTicks } from "../utils";
import lookupFieldUnit from "./lookup-field-unit";
import { getFieldStroke } from "./utils";

// Length of the playback window (30 s, converted with `secondsToNanoseconds`).
// Passed to `calculateRecordWindow` to derive both the window chart's x-axis
// domain and the band highlighted on the overview chart.
const WINDOW_SIZE = secondsToNanoseconds(30);

/**
 * Chart panel visualization for a single topic.
 *
 * Plots the panel's selected fields for the current record snapshot
 * (`useRecords`). An optional overview chart built from `useAllRecords` can be
 * toggled with the bottom-left button; it marks the window the main chart is
 * showing. Snapshot buffering is reported to the panel.
 */
export function ChartVisualization({
  panel,
  topic,
}: {
  panel: ChartPanel;
  topic: Topic;
}) {
  const { ref, width, height } = useResizeObserver();

  const [showOverview, setShowOverview] = useState(false);

  const playbackSource = useLoadedPlaybackSource();
  const formatPlaybackTimestamp = useFormatPlaybackTimestamp();

  const storeParams = useVisualizationStoreParams(VisualizationType.Chart);
  const { snapshot, isPlaceholder } = useRecords({
    recordType: "default",
    topic,
    ...storeParams,
  });

  useSkipToFirstTimestamp(panel, topic, snapshot.status === "fulfilled");

  // The full record set is only requested while the overview is visible.
  const allRecordsQuery = useAllRecords({ topic, enabled: showOverview });

  const areRecordsBuffering = snapshot.status === "pending" || isPlaceholder;

  // TODO: Once the overview's visual bugs are fixed, report it as buffering
  // while it is shown and `allRecordsQuery` is still loading. Until then the
  // overview never contributes to the panel's buffering state.
  const isOverviewBuffering = false;

  useUpdatePanelBuffering(areRecordsBuffering || isOverviewBuffering);

  const { fields } = getPrimaryTopicDescriptor(panel);

  // Overview strip below the window chart: null when hidden, otherwise a
  // loading/error placeholder or the memoized overview chart itself.
  const renderOverview = (
    chartWidth: number,
    chartHeight: number,
    requestTimestamp: bigint,
  ): React.ReactNode => {
    if (!showOverview) {
      return null;
    }

    if (allRecordsQuery.status === "loading") {
      return (
        <Box sx={{ width: "100%", height: chartHeight }}>
          <Loading type="circular" />
        </Box>
      );
    }

    if (allRecordsQuery.status === "error") {
      return (
        <Box sx={{ width: "100%", height: chartHeight }}>
          <ErrorMessage>
            An error occurred. Unable to show overview chart
          </ErrorMessage>
        </Box>
      );
    }

    return (
      <RequestTimestampContext.Provider value={requestTimestamp}>
        <MemoizedOverviewChart
          playbackBounds={playbackSource.bounds}
          formatPlaybackTimestamp={formatPlaybackTimestamp}
          topic={topic}
          width={chartWidth}
          height={chartHeight}
          chartFields={fields}
          records={allRecordsQuery.data}
        />
      </RequestTimestampContext.Provider>
    );
  };

  const renderContent = (): React.ReactNode => {
    if (snapshot.status === "pending") {
      return <LoadingFeedback description="data points" />;
    }

    if (snapshot.status === "rejected") {
      return (
        <ErrorMessage>An error occurred. Can't get chart data</ErrorMessage>
      );
    }

    // Nothing can be laid out until the container has been measured.
    if (width == null || height == null) {
      return null;
    }

    // When shown, the overview takes the bottom 20% of the panel.
    const overviewChartHeight = showOverview ? Math.round(height * 0.2) : 0;
    const windowChartHeight = height - overviewChartHeight;

    let statusOverlay: React.ReactNode = null;
    if (isPlaceholder) {
      statusOverlay = <LoadingFeedback description="data points" />;
    } else if (fields.length === 0) {
      statusOverlay = (
        <Backdrop open sx={{ position: "absolute" }}>
          <Typography variant="h5" component="p">
            Select fields to plot in the panel controls
          </Typography>
        </Backdrop>
      );
    }

    return (
      <Box
        sx={{
          width: "100%",
          height: "100%",
          overflow: "hidden",
          position: "relative",
        }}
      >
        <WindowChart
          topic={topic}
          width={width}
          height={windowChartHeight}
          chartFields={fields}
          records={snapshot.value}
          timestamp={snapshot.request.timestamp}
        />
        {renderOverview(width, overviewChartHeight, snapshot.request.timestamp)}
        <Tooltip title="Toggle overview chart">
          <ToggleButton
            sx={{
              position: "absolute",
              left: (theme) => theme.spacing(1),
              bottom: (theme) => theme.spacing(1),
            }}
            value={true}
            selected={showOverview}
            onChange={() => setShowOverview(!showOverview)}
          >
            <InsertChart />
          </ToggleButton>
        </Tooltip>
        {statusOverlay}
      </Box>
    );
  };

  return <PanelLayout contentRef={ref}>{renderContent()}</PanelLayout>;
}

// Semi-transparent band in the theme's secondary colour. `<WindowReferenceArea />`
// uses it to mark the current window on the overview chart.
const Rect = styled("rect")(({ theme: { palette } }) => ({
  fill: palette.secondary.main,
  fillOpacity: 0.5,
  height: "100%",
}));

// `<WindowReferenceArea />` needs the store request's timestamp to position the
// highlighted window on the overview chart. Passing it down as a prop would
// hand `<MemoizedOverviewChart />` a new value every time the timestamp
// changes, which happens frequently during playback and would make its
// memoization pointless. Instead, `<ChartVisualization />` provides the
// timestamp through this context and `<WindowReferenceArea />` reads it
// directly, so `<MemoizedOverviewChart />` itself doesn't need to re-render.
const [useRequestTimestampContext, RequestTimestampContext] =
  createSafeContext<bigint>("RequestTimestamp");

/**
 * Highlights, on the overview chart, the playback window that the window chart
 * is currently showing. The window is derived from the request timestamp in
 * `RequestTimestampContext` using the same `calculateRecordWindow` call as
 * `<WindowChart />`.
 */
function WindowReferenceArea() {
  const timestamp = useRequestTimestampContext();
  const { bounds } = useLoadedPlaybackSource();
  const xScale = useXScale<"linear">();

  const recordWindow = calculateRecordWindow(WINDOW_SIZE, timestamp, bounds);

  // Convert both window edges from UTC to playback-relative time, then into
  // pixel positions on the overview's x-axis.
  const [leftEdge, rightEdge] = [
    recordWindow.startTime,
    recordWindow.endTime,
  ].map((utcTime) =>
    xScale(utcToRelativeNanoseconds(utcTime, bounds.startTime)),
  );

  return <Rect x={leftEdge} width={rightEdge - leftEdge} />;
}

// Replacement `line` slot for `<LinePlot />`: an unfilled path stroked 2px wide
// in the series colour supplied through `ownerState`.
const Path = styled("path")<AnimatedLineProps>(({ ownerState: { color } }) => ({
  fill: "none",
  stroke: color,
  strokeWidth: 2,
}));

// Overview chart covering the whole playback range. It is memoized so it only
// re-renders when its data or layout props change. The moving window highlight
// (`<WindowReferenceArea />`) gets the request timestamp from context instead
// of from props.
const MemoizedOverviewChart = React.memo(function MemoizedOverviewChart({
  playbackBounds,
  formatPlaybackTimestamp,
  topic,
  width,
  height,
  chartFields,
  records,
}: {
  playbackBounds: TimeRange;
  formatPlaybackTimestamp: (timestamp: bigint) => string;
  topic: Topic;
  width: number;
  height: number;
  chartFields: ReadonlyArray<string>;
  records: ReadonlyArray<Record>;
}) {
  const clipPathId = `${useId()}-clip-path`;

  // Everything except the size is memoized together so that resizing alone
  // doesn't force the chart data to be rebuilt.
  const chartProps: StrictOmit<ChartContainerProps, "width" | "height"> =
    useMemo(() => {
      const { timestamps, fieldSeries } = selectChartData(
        topic.typeName,
        chartFields,
        records,
        playbackBounds.startTime,
      );

      // The x-axis spans the entire playback, expressed relative to its start.
      const playbackDurationNs = Number(
        playbackBounds.endTime - playbackBounds.startTime,
      );

      return {
        xAxis: [
          {
            data: timestamps,
            valueFormatter: formatPlaybackTimestamp,
            min: 0,
            max: playbackDurationNs,
            tickInterval: calculateOverviewTicks(playbackBounds, 5),
          },
        ],
        series: fieldSeries.map((fieldSeriesEntry) => ({
          ...fieldSeriesEntry,
          type: "line",
          curve: "linear",
        })),
        margin: { top: 10, right: 10, bottom: 40, left: 100 },
        skipAnimation: true,
        disableAxisListener: true,
        children: (
          <>
            <ChartsClipPath id={clipPathId} />
            <ChartsXAxis />
            <ChartsGrid vertical />
            <g clipPath={`url(#${clipPathId})`}>
              <LinePlot slots={{ line: Path }} />
              <WindowReferenceArea />
            </g>
          </>
        ),
      };
    }, [
      topic.typeName,
      chartFields,
      records,
      formatPlaybackTimestamp,
      playbackBounds,
      clipPathId,
    ]);

  return <ChartContainer width={width} height={height} {...chartProps} />;
});

/**
 * Main chart: plots `records` for the playback window around `timestamp`.
 * It draws a reference line at `timestamp` itself and renders custom hover
 * feedback through `<PointerPositionIndicator />`.
 */
function WindowChart({
  topic,
  width,
  height,
  chartFields,
  records,
  timestamp,
}: {
  topic: Topic;
  width: number;
  height: number;
  chartFields: ReadonlyArray<string>;
  records: ReadonlyArray<PlayerRecord<"default">>;
  timestamp: bigint;
}) {
  const reactId = useId();
  const clipPathId = `${reactId}-clip-path`;

  const { bounds } = useLoadedPlaybackSource();
  const formatPlaybackTimestamp = useFormatPlaybackTimestamp();

  const chartData = selectChartData(
    topic.typeName,
    chartFields,
    records,
    bounds.startTime,
  );

  const recordWindow = calculateRecordWindow(WINDOW_SIZE, timestamp, bounds);

  // The visible x-domain is the record window, in playback-relative time.
  const [windowStart, windowEnd] = [
    recordWindow.startTime,
    recordWindow.endTime,
  ].map((utcTime) => utcToRelativeNanoseconds(utcTime, bounds.startTime));
  const domain = { min: windowStart, max: windowEnd };

  const currentTimePosition = utcToRelativeNanoseconds(
    timestamp,
    bounds.startTime,
  );

  return (
    <ChartContainer
      width={width}
      height={height}
      xAxis={[
        {
          data: chartData.timestamps,
          valueFormatter: formatPlaybackTimestamp,
          min: domain.min,
          max: domain.max,
          tickInterval: calculateWindowTicks(
            recordWindow,
            secondsToNanoseconds(5),
            WINDOW_SIZE,
            bounds.startTime,
          ),
        },
      ]}
      series={chartData.fieldSeries.map((fieldSeriesEntry) => ({
        ...fieldSeriesEntry,
        type: "line",
        curve: "linear",
      }))}
      margin={{ right: 10, bottom: 30, left: 100 }}
      skipAnimation
      disableAxisListener
    >
      <ChartsClipPath id={clipPathId} offset={{ top: 5, bottom: 5 }} />
      <ChartsLegend />
      <ChartsXAxis />
      <ChartsYAxis />
      <ChartsGrid vertical horizontal />
      <g clipPath={`url(#${clipPathId})`}>
        <LinePlot slots={{ line: Path }} />
        <MarkPlot />
        <ChartsReferenceLine x={currentTimePosition} />
        <PointerPositionIndicator chartData={chartData} domain={domain} />
      </g>
    </ChartContainer>
  );
}

// Dashed vertical guide drawn at the hovered record's x position. It ignores
// pointer events so it never intercepts hovering over the chart.
const AxisHighlight = styled("line")(({ theme: { palette } }) => ({
  pointerEvents: "none",
  fill: "none",
  stroke: palette.text.primary,
  strokeWidth: 1,
  strokeDasharray: "5 2",
}));

//   MUI's built-in interaction feedback (the axis highlight, the line mark
// highlights and the tooltip) keeps internal state that falls out of step when
// the plotted records change because the playback time moved. The tooltip's
// heading, the formatted playback timestamp, can even disagree with the record
// whose values it lists.
//   This component replaces that feedback. It stores only the pointer position
// and, on every render, derives the highlights and tooltip from the record
// nearest to that position, so they stay correct as the records change.
function PointerPositionIndicator({
  chartData,
  domain,
}: {
  chartData: ChartData;
  domain: { min: number; max: number };
}) {
  const drawingArea = useDrawingArea();

  const [pointerLocation, setPointerLocation] =
    useState<PointerLocation | null>(null);
  const pointerContainerProps = usePointerLocation({
    trigger: "pointerover",
    onChange: setPointerLocation,
    onInteractionEnd() {
      setPointerLocation(null);
    },
  });

  const formatPlaybackTimestamp = useFormatPlaybackTimestamp();

  const xScale = useXScale<"linear">();
  const yScale = useYScale<"linear">();

  let hoverFeedback: React.ReactNode = null;

  if (pointerLocation != null) {
    // The x-scale maps onto <svg> coordinates ([left, left + width]), but the
    // pointer location is relative to the drawing area. Shift it by the
    // drawing area's left offset before inverting.
    const pointerValue = xScale.invert(pointerLocation.x + drawingArea.left);
    const closestIndex = findClosestInDomainIndex(
      chartData.timestamps,
      domain,
      pointerValue,
    );

    if (closestIndex != null) {
      const timestamp = chartData.timestamps[closestIndex];
      const xCoordinate = xScale(timestamp);
      const [yRangeStart, yRangeEnd] = yScale.range();

      const points = chartData.fieldSeries.map((fieldSeriesEntry) => ({
        key: fieldSeriesEntry.id,
        id: fieldSeriesEntry.id,
        label: fieldSeriesEntry.label,
        color: fieldSeriesEntry.color,
        value: fieldSeriesEntry.data[closestIndex],
      }));

      // Zero-size virtual anchor at the pointer's client position, so the
      // Popper follows the pointer rather than a DOM element.
      const anchorX = pointerLocation.clientX ?? 0;
      const anchorY = pointerLocation.clientY ?? 0;
      const anchorRect = {
        width: 0,
        height: 0,
        x: anchorX,
        y: anchorY,
        top: anchorY,
        right: anchorX,
        bottom: anchorY,
        left: anchorX,
        toJSON: () => "",
      };

      hoverFeedback = (
        <>
          <AxisHighlight
            x1={xCoordinate}
            y1={yRangeStart}
            x2={xCoordinate}
            y2={yRangeEnd}
          />
          {points.map((point) =>
            point.value == null ? null : (
              <LineHighlightElement
                key={point.key}
                id={point.id}
                color={point.color}
                x={xCoordinate}
                y={yScale(point.value)}
              />
            ),
          )}
          <Popper
            open
            anchorEl={{ getBoundingClientRect: () => anchorRect }}
            placement="right-start"
            className={chartsTooltipClasses.root}
            sx={{
              pointerEvents: "none",
              zIndex: (theme) => theme.zIndex.modal,
            }}
          >
            <ChartsTooltipPaper className={chartsTooltipClasses.paper}>
              <ChartsTooltipTable className={chartsTooltipClasses.table}>
                <thead>
                  <ChartsTooltipRow>
                    <ChartsTooltipCell colSpan={3}>
                      <Typography>{formatPlaybackTimestamp(timestamp)}</Typography>
                    </ChartsTooltipCell>
                  </ChartsTooltipRow>
                </thead>
                <tbody>
                  {points.map((point) => (
                    <ChartsTooltipRow
                      key={point.key}
                      className={chartsTooltipClasses.row}
                    >
                      <ChartsTooltipCell
                        className={`${chartsTooltipClasses.cell} ${chartsTooltipClasses.markCell}`}
                      >
                        <ChartsTooltipMark
                          className={chartsTooltipClasses.mark}
                          color={point.color}
                        />
                      </ChartsTooltipCell>
                      <ChartsTooltipCell
                        className={`${chartsTooltipClasses.cell} ${chartsTooltipClasses.labelCell}`}
                      >
                        <Typography>{point.label}</Typography>
                      </ChartsTooltipCell>
                      <ChartsTooltipCell
                        className={`${chartsTooltipClasses.cell} ${chartsTooltipClasses.valueCell}`}
                      >
                        {point.value ?? "-"}
                      </ChartsTooltipCell>
                    </ChartsTooltipRow>
                  ))}
                </tbody>
              </ChartsTooltipTable>
            </ChartsTooltipPaper>
          </Popper>
        </>
      );
    }
  }

  return (
    <>
      <rect
        {...pointerContainerProps}
        x={drawingArea.left}
        y={drawingArea.top}
        width={drawingArea.width}
        height={drawingArea.height}
        fill="transparent"
      />
      {hoverFeedback}
    </>
  );
}

// Returns the index into `timestamps` of the timestamp inside `domain` nearest
// to `value`, or `null` when no timestamp falls inside `domain`. Timestamps
// outside the domain are deliberately ignored: the chart plots some records
// beyond each edge so the line doesn't vanish at the boundaries, but those
// shouldn't be picked while hovering. A value exactly halfway between two
// timestamps resolves to the earlier one. The search relies on `timestamps`
// being sorted ascending with no duplicates.
function findClosestInDomainIndex(
  timestamps: ReadonlyArray<number>,
  domain: { min: number; max: number },
  value: number,
): number | null {
  const candidates = timestamps.filter(
    (timestamp) => domain.min <= timestamp && timestamp <= domain.max,
  );

  if (candidates.length === 0) {
    return null;
  }

  const lastIndex = candidates.length - 1;
  const atOrBefore = candidates.findLastIndex(
    (timestamp) => timestamp <= value,
  );

  let nearest: number;
  if (atOrBefore === -1) {
    // Value is before every candidate.
    nearest = 0;
  } else if (atOrBefore === lastIndex) {
    // Value is at or after the last candidate (also covers one candidate).
    nearest = lastIndex;
  } else {
    // Value lies between two candidates; take the nearer one.
    const after = atOrBefore + 1;
    nearest =
      value - candidates[atOrBefore] <= candidates[after] - value
        ? atOrBefore
        : after;
  }

  // Map back from the filtered list to an index into the full list.
  return timestamps.indexOf(candidates[nearest]);
}

// One plotted line: a field's values, index-aligned with `ChartData.timestamps`
// (`null` where the record had no numeric value for the field).
interface ChartFieldSeries {
  id: string;
  label: string;
  color: string;
  data: Array<number | null>;
}

// Column-oriented chart input: one playback-relative timestamp per record plus
// one series per selected field.
interface ChartData {
  timestamps: Array<number>;
  fieldSeries: ReadonlyArray<ChartFieldSeries>;
}

/**
 * Flattens records into the column-oriented shape the charts consume.
 *
 * Each record adds one entry to `timestamps` (its timestamp relative to
 * `startTime`) and one entry to every field's `data`, so all arrays stay
 * index-aligned. Field values are read from `queryData` when the record has
 * that property and from `data` otherwise. Non-number values become `null`.
 * Series labels include the field's unit when `lookupFieldUnit` knows it.
 */
function selectChartData(
  topicTypeName: Topic["typeName"],
  chartFields: ReadonlyArray<string>,
  records: ReadonlyArray<PlayerRecord<"default">> | ReadonlyArray<Record>,
  startTime: bigint,
): ChartData {
  const fieldSeries = chartFields.map((chartField, index) => {
    const fieldUnit = lookupFieldUnit(topicTypeName, chartField);

    return {
      id: chartField,
      label: fieldUnit == null ? chartField : `${chartField} (${fieldUnit})`,
      color: getFieldStroke(index),
      data: new Array<number | null>(),
    };
  });

  const timestamps = new Array<number>();

  for (const record of records) {
    timestamps.push(utcToRelativeNanoseconds(record.timestamp, startTime));

    // The payload location depends only on the record, not on the field.
    const payload = "queryData" in record ? record.queryData : record.data;

    // Each series' `id` is the field path it plots.
    for (const series of fieldSeries) {
      const fieldValue: unknown = get(payload, series.id);
      series.data.push(typeof fieldValue === "number" ? fieldValue : null);
    }
  }

  return { timestamps, fieldSeries };
}

/**
 * Evenly spaced x-axis tick positions for the overview chart.
 *
 * The playback duration is split into `tickCount + 1` equal segments. A tick is
 * placed at each interior boundary, so no tick sits on either end of the axis.
 * Positions are nanoseconds relative to `playerBounds.startTime`, rounded down
 * to whole nanoseconds.
 *
 * At most `tickCount` ticks are returned. When the duration is too short to
 * separate them, coinciding positions are collapsed. A non-positive duration,
 * or a `tickCount` below 1 or non-finite, yields no ticks.
 */
function calculateOverviewTicks(
  playerBounds: TimeRange,
  tickCount: number,
): Array<number> {
  const duration = playerBounds.endTime - playerBounds.startTime;

  if (duration <= BigInt(0) || !Number.isFinite(tickCount) || tickCount < 1) {
    return [];
  }

  const interiorTickCount = Math.floor(tickCount);
  const segmentCount = BigInt(interiorTickCount + 1);
  const ticks = new Array<number>();

  for (let index = 1; index <= interiorTickCount; index++) {
    // Each position is computed independently with exact integer (bigint)
    // division. Repeatedly adding a floored interval instead yielded an extra
    // tick for most durations and never terminated when the interval floored
    // to 0.
    const tick = Number((duration * BigInt(index)) / segmentCount);

    // For very short durations, neighbouring positions can floor to the same
    // nanosecond (or to 0). Skip those rather than emit duplicate ticks.
    if (tick > 0 && tick !== ticks[ticks.length - 1]) {
      ticks.push(tick);
    }
  }

  return ticks;
}
