import React, { useId, useMemo, useState } from "react";
import {
  Backdrop,
  Box,
  Button,
  Popper,
  styled,
  Typography,
} from "@mui/material";
import type { AnimatedLineProps, ChartContainerProps } from "@mui/x-charts";
import {
  ChartContainer,
  ChartsClipPath,
  ChartsGrid,
  ChartsLegend,
  ChartsReferenceLine,
  ChartsTooltipCell,
  chartsTooltipClasses,
  ChartsTooltipMark,
  ChartsTooltipPaper,
  ChartsTooltipRow,
  ChartsTooltipTable,
  ChartsXAxis,
  ChartsYAxis,
  LineHighlightElement,
  LinePlot,
  MarkPlot,
  useDrawingArea,
  useXScale,
  useYScale,
} from "@mui/x-charts";
import type { StrictOmit } from "ts-essentials";
import useResizeObserver from "use-resize-observer";
import { ErrorMessage } from "~/components/error-message";
import { utcToRelativeNanoseconds } from "~/lib/dates";
import { get } from "~/lib/std";
import type { TimeRange } from "~/types";
import { usePlayerActions } from "../../actions";
import { LoadingFeedback, PanelLayout } from "../../components";
import type { PointerLocation } from "../../hooks";
import { usePointerLocation } from "../../hooks";
import type { ChartPanel } from "../../panels";
import {
  useFormatPlaybackTimestamp,
  useLoadedPlaybackSource,
} from "../../playback";
import type { PlayerRecord } from "../../record-store";
import type { TimestepValue } from "../../types";
import { calculateRecordWindow } from "../../utils";
import { calculateWindowTicks, getWindowSizeForTimestep } from "../utils";
import type { ChartRecords } from "./use-chart-records";
import { ChartError, useChartRecords } from "./use-chart-records";

/**
 * Top-level visualization for a chart panel.
 *
 * Renders a window chart of the panel's sources and, when the panel has the
 * overview enabled, an overview chart stacked under it that takes a fifth of
 * the measured height. A loading indicator is shown while records are pending
 * (and overlaid while a placeholder snapshot is displayed); an error message
 * is shown if they couldn't be fetched, with a button to close the overview
 * chart when it was the overview data that failed. Nothing is rendered until
 * the panel's content area has been measured.
 */
export function ChartVisualization({ panel }: { panel: ChartPanel }) {
  const { ref, width, height } = useResizeObserver();

  const playerActions = usePlayerActions();

  const [chartRecordsSnapshot, isPlaceholderSnapshot] = useChartRecords({
    panel,
  });

  function handleCloseOverviewChart(): void {
    // TODO: Clear store error?
    playerActions.toggleOverviewChart(panel, false);
  }

  function renderContent(): React.ReactNode {
    if (chartRecordsSnapshot.status === "pending") {
      return <LoadingFeedback description="data points" />;
    }

    if (chartRecordsSnapshot.status === "rejected") {
      const { reason } = chartRecordsSnapshot;

      // The only failure this component offers a remedy for is one fetching
      // overview data: the user can turn the overview chart off.
      if (!(reason instanceof ChartError && reason.source === "overview")) {
        return (
          <ErrorMessage>An error occurred. Can't get chart data</ErrorMessage>
        );
      }

      return (
        <ErrorMessage disableTypography>
          <ErrorMessage.Paragraph>
            An error occurred getting overview chart data
          </ErrorMessage.Paragraph>
          <Button
            color="primary"
            variant="outlined"
            onClick={handleCloseOverviewChart}
          >
            Close overview chart
          </Button>
        </ErrorMessage>
      );
    }

    // The charts need concrete pixel dimensions, so wait for the content area
    // to be measured before laying anything out.
    if (width == null || height == null) {
      return null;
    }

    const { showOverview, timestep, timestamp, sources } =
      chartRecordsSnapshot.value;

    const overviewChartHeight = showOverview ? Math.round(height * 0.2) : 0;

    let overlay: React.ReactNode = null;
    if (isPlaceholderSnapshot) {
      overlay = <LoadingFeedback description="data points" />;
    } else if (sources.length === 0) {
      overlay = (
        <Backdrop open sx={{ position: "absolute" }}>
          <Typography variant="h5" component="p">
            Select fields to plot in the panel controls
          </Typography>
        </Backdrop>
      );
    }

    return (
      <Box
        sx={{
          width: "100%",
          height: "100%",
          overflow: "hidden",
          position: "relative",
        }}
      >
        <WindowChart
          width={width}
          height={height - overviewChartHeight}
          timestep={timestep}
          timestamp={timestamp}
          sources={sources}
        />
        {showOverview ? (
          <OverviewChart
            width={width}
            height={overviewChartHeight}
            timestep={timestep}
            timestamp={timestamp}
            sources={sources}
          />
        ) : null}
        {overlay}
      </Box>
    );
  }

  return <PanelLayout contentRef={ref}>{renderContent()}</PanelLayout>;
}

/**
 * Translucent, full-height `<rect>` filled with the theme's secondary color.
 * Used to shade the window chart's record window on the overview chart.
 */
const Rect = styled("rect")(({ theme: { palette } }) => ({
  height: "100%",
  fill: palette.secondary.main,
  fillOpacity: 0.5,
}));

/**
 * Shades the span of the overview chart's x-axis covered by the record window
 * that the window chart currently displays for `timestep` and `timestamp`.
 *
 * Must be rendered inside a chart container, since it reads that chart's
 * x-scale to convert times into pixel positions.
 */
function WindowReferenceArea({
  timestep,
  timestamp,
}: Pick<ChartRecords, "timestep" | "timestamp">) {
  const { bounds } = useLoadedPlaybackSource();
  const xScale = useXScale<"linear">();

  const { startTime, endTime } = calculateRecordWindow(
    getWindowSizeForTimestep(timestep),
    timestamp,
    bounds,
  );

  // The x-axis is in nanoseconds relative to the playback start time.
  const toPixel = (time: typeof startTime) =>
    xScale(utcToRelativeNanoseconds(time, bounds.startTime));

  const left = toPixel(startTime);
  const right = toPixel(endTime);

  return <Rect x={left} width={right - left} />;
}

/**
 * Replacement `line` slot for `LinePlot`: an unfilled, 2px-wide stroke drawn
 * in the series color carried on `ownerState`.
 */
const Path = styled("path")<AnimatedLineProps>((props) => ({
  fill: "none",
  stroke: props.ownerState.color,
  strokeWidth: 2,
}));

/**
 * `ChartContainer` props shared by the overview's base chart and every
 * overview line chart layered on top of it. Identical margins keep the
 * stacked charts' drawing areas aligned.
 */
const commonOverviewChartProps = {
  skipAnimation: true,
  disableAxisListener: true,
  margin: { top: 10, right: 10, bottom: 40, left: 100 },
} satisfies Partial<ChartContainerProps>;

/**
 * Overview of each source's records across the entire playback range, with
 * the window chart's current record window shaded.
 *
 * Only sources that actually have overview records get a line.
 */
function OverviewChart({
  width,
  height,
  timestep,
  timestamp,
  sources,
}: {
  width: number;
  height: number;
} & Pick<ChartRecords, "timestep" | "timestamp" | "sources">) {
  const { bounds: playbackBounds } = useLoadedPlaybackSource();

  const formatPlaybackTimestamp = useFormatPlaybackTimestamp();

  const id = useId();
  const clipPathId = `${id}-clip-path`;

  // The overview chart is composed of several sub-charts. The first is a base
  // chart with no real data that just renders the axes and reference area.
  // This chart can re-render often with little performance impact since there's
  // no data to be processed. Remaining sub-charts are rendered per chart source
  // and are memoized to reduce the performance impact when unrelated data like
  // the playback timestamp changes. These sub-charts are absolutely positioned
  // on top of the base chart and don't render axes of their own since the base
  // chart renders those.
  //
  // The reason for this roundabout method is because an individual source's
  // array of records is memoizable but the array of sources is not. Trying to
  // render this unstable array of sources in the same chart would cause massive
  // performance issues each time the timestamp changes because MUI's charts
  // reprocess series data each time the `series` prop changes, even if the
  // underlying series' `data` hasn't changed which is the case here.
  //
  // Each source's overview records are checked rather than asserted non-null:
  // a source without them simply gets no line instead of `OverviewLine`
  // crashing when it iterates a missing array.
  return (
    <Box sx={{ width, height, position: "relative" }}>
      <ChartContainer
        {...commonOverviewChartProps}
        width={width}
        height={height}
        xAxis={[
          {
            data: [],
            valueFormatter(value) {
              return formatPlaybackTimestamp(value);
            },
            min: 0,
            max: Number(playbackBounds.endTime - playbackBounds.startTime),
            tickInterval: calculateOverviewTicks(playbackBounds, 5),
          },
        ]}
        series={[]}
      >
        <ChartsClipPath id={clipPathId} />
        <ChartsXAxis />
        <ChartsGrid vertical />
        <g clipPath={`url(#${clipPathId})`}>
          <WindowReferenceArea timestep={timestep} timestamp={timestamp} />
        </g>
      </ChartContainer>
      {sources.map((source) =>
        source.overview == null ? null : (
          <OverviewLine
            key={source.id}
            id={source.id}
            width={width}
            height={height}
            field={source.field}
            unit={source.unit}
            color={source.color}
            data={source.overview}
            startTime={playbackBounds.startTime}
            endTime={playbackBounds.endTime}
          />
        ),
      )}
    </Box>
  );
}

/**
 * One source's overview line. It is drawn as its own axis-less chart so it can
 * be absolutely positioned on top of `OverviewChart`'s base chart.
 *
 * Both the component and its chart props are memoized. Because the chart props
 * depend only on this source's data and the playback bounds, the series isn't
 * re-processed when unrelated values such as the playback timestamp change.
 */
const OverviewLine = React.memo(function OverviewLine({
  id,
  width,
  height,
  field,
  unit,
  color,
  data,
  startTime,
  endTime,
}: {
  id: string;
  width: number;
  height: number;
  field: string;
  unit: string | null;
  color: string;
  data: ReadonlyArray<PlayerRecord<"default">>;
} & TimeRange) {
  // Derive the clip path's DOM id from `useId`, like the other charts in this
  // file do. Nothing here guarantees source ids are unique across the whole
  // document (e.g. the same source plotted in two panels). With a duplicate id,
  // `url(#...)` would resolve to whichever clip path appears first in the DOM.
  const reactId = useId();
  const clipPathId = `${reactId}-clip-path`;

  const chartProps: StrictOmit<ChartContainerProps, "width" | "height"> =
    useMemo(() => {
      const seriesData = prepareSeriesData(
        id,
        field,
        unit,
        color,
        data,
        startTime,
      );

      return {
        ...commonOverviewChartProps,
        xAxis: [
          {
            data: seriesData.timestamps,
            min: 0,
            max: Number(endTime - startTime),
          },
        ],
        series: [
          {
            type: "line",
            curve: "linear",
            label: getLabel(field, unit),
            color,
            data: seriesData.values,
          },
        ],
        children: (
          <>
            <ChartsClipPath id={clipPathId} />
            <g clipPath={`url(#${clipPathId})`}>
              <LinePlot
                slots={{
                  line: Path,
                }}
              />
            </g>
          </>
        ),
      };
    }, [id, color, data, endTime, field, startTime, unit, clipPathId]);

  return (
    <ChartContainer
      sx={{ position: "absolute", top: 0, left: 0 }}
      width={width}
      height={height}
      {...chartProps}
    />
  );
});

/**
 * Plots each source's records inside the current record window. The window is
 * computed from `timestep`'s window size and `timestamp`, clamped by
 * `calculateRecordWindow` to the playback bounds. The chart also shows a
 * reference line at `timestamp` and custom pointer highlights and tooltip.
 *
 * Every source gets its own x-axis carrying that source's timestamps. Only the
 * `x-primary` axis is rendered; it shares the same domain and provides the
 * ticks and labels.
 */
function WindowChart({
  width,
  height,
  timestep,
  timestamp,
  sources,
}: {
  width: number;
  height: number;
  timestep: TimestepValue;
  timestamp: bigint;
} & Pick<ChartRecords, "timestamp" | "sources">) {
  const id = useId();
  const clipPathId = `${id}-clip-path`;

  const windowSize = getWindowSizeForTimestep(timestep);

  const { bounds: playbackBounds } = useLoadedPlaybackSource();
  const playbackStartTime = playbackBounds.startTime;

  const formatPlaybackTimestamp = useFormatPlaybackTimestamp();

  const recordWindow = calculateRecordWindow(
    windowSize,
    timestamp,
    playbackBounds,
  );

  // Visible x-axis extent, in nanoseconds relative to the playback start.
  const domain = {
    min: utcToRelativeNanoseconds(recordWindow.startTime, playbackStartTime),
    max: utcToRelativeNanoseconds(recordWindow.endTime, playbackStartTime),
  };

  const seriesData = sources.map((source) =>
    prepareSeriesData(
      source.id,
      source.field,
      source.unit,
      source.color,
      source.window,
      playbackStartTime,
    ),
  );

  const windowTicks = calculateWindowTicks(
    recordWindow,
    windowSize / 6n,
    windowSize,
    playbackStartTime,
  );

  const playheadPosition = utcToRelativeNanoseconds(
    timestamp,
    playbackStartTime,
  );

  return (
    <ChartContainer
      width={width}
      height={height}
      xAxis={[
        {
          id: "x-primary",
          scaleType: "linear",
          valueFormatter(value) {
            return formatPlaybackTimestamp(value);
          },
          ...domain,
          tickInterval: windowTicks,
        },
        // `prepareSeriesData` returns the source's id unchanged, so these axis
        // ids line up with the `xAxisId`s used by the series below.
        ...seriesData.map((prepared) => ({
          id: `${prepared.id}-x`,
          scaleType: "linear" as const,
          ...domain,
          data: prepared.timestamps,
        })),
      ]}
      series={seriesData.map((prepared) => ({
        type: "line" as const,
        curve: "linear" as const,
        xAxisId: `${prepared.id}-x`,
        data: prepared.values,
        color: prepared.color,
        label: prepared.label,
      }))}
      margin={{
        right: 10,
        bottom: 30,
        left: 100,
      }}
      skipAnimation
      disableAxisListener
    >
      <ChartsClipPath id={clipPathId} offset={{ top: 5, bottom: 5 }} />
      <ChartsLegend />
      <ChartsXAxis axisId="x-primary" />
      <ChartsYAxis />
      <ChartsGrid vertical horizontal />
      <g clipPath={`url(#${clipPathId})`}>
        <LinePlot
          slots={{
            line: Path,
          }}
        />
        <MarkPlot />
        <ChartsReferenceLine x={playheadPosition} />
        <PointerPositionIndicator domain={domain} series={seriesData} />
      </g>
    </ChartContainer>
  );
}

/**
 * Dashed vertical guide line drawn at the highlighted record's x position.
 * It ignores pointer events so it never intercepts the pointer-tracking rect.
 */
const AxisHighlight = styled("line")(({ theme: { palette } }) => ({
  pointerEvents: "none",
  fill: "none",
  stroke: palette.text.primary,
  strokeWidth: 1,
  strokeDasharray: "5 2",
}));

/**
 * Draws the window chart's pointer indicators for the record closest to the
 * pointer: a dashed vertical line, a highlight on each series' line, and a
 * tooltip.
 *
 * MUI's interaction indicators (the axis highlight, line mark highlights and
 * tooltip) don't work as expected when the records change as a result of
 * playback time changing. Their internal state is structured in such a way
 * that even the tooltip title, the formatted playback timestamp, can get out
 * of sync with the record whose values are shown in the tooltip.
 *
 * This component rolls most of the logic itself. It tracks and stores only the
 * pointer position, and on every render it provides highlights and a tooltip
 * for the record closest to that position. The indicators therefore stay
 * correct even when the underlying records change.
 *
 * When `series` is empty, only the transparent pointer-tracking rect is
 * rendered, because there are no records to highlight.
 */
function PointerPositionIndicator({
  domain,
  series,
}: {
  domain: { min: number; max: number };
  series: ReadonlyArray<PreparedSeries>;
}) {
  const drawingArea = useDrawingArea();

  const [pointerLocation, setPointerLocation] =
    useState<PointerLocation | null>(null);
  const pointerContainerProps = usePointerLocation({
    trigger: "pointerover",
    onChange: setPointerLocation,
    onInteractionEnd() {
      setPointerLocation(null);
    },
  });

  const formatPlaybackTimestamp = useFormatPlaybackTimestamp();

  const xScale = useXScale<"linear">();
  const yScale = useYScale<"linear">();

  // TODO: Don't just use the first series. This only works now because all
  //  series come from the same topic and so have matching timestamps.
  // `at(0)` rather than `[0]`: with no series (no fields selected) this is
  // `undefined`, which is checked below, instead of a crash reading
  // `.timestamps` as soon as the pointer enters the chart.
  const referenceSeries = series.at(0);

  function computeIndicators() {
    if (pointerLocation == null || referenceSeries == null) {
      return null;
    }

    // The x-scale's range is equivalent to the drawing area's coordinates -
    // [left, left + width] - effectively being relative to the parent <svg>,
    // but the pointer location is already relative to the drawing area so
    // it needs to be transformed back into the <svg>'s coordinate frame.
    const pointerValue = xScale.invert(pointerLocation.x + drawingArea.left);

    const closestPointIndex = findClosestPointIndex(
      referenceSeries.timestamps,
      domain,
      pointerValue,
    );
    if (closestPointIndex == null) {
      return null;
    }

    const timestamp = referenceSeries.timestamps[closestPointIndex];

    // Zero-size rect at the pointer's client position. It serves as the
    // tooltip Popper's virtual anchor element.
    const clientX = pointerLocation.clientX ?? 0;
    const clientY = pointerLocation.clientY ?? 0;
    const tooltipRect = {
      width: 0,
      height: 0,
      x: clientX,
      y: clientY,
      top: clientY,
      right: clientX,
      bottom: clientY,
      left: clientX,
      toJSON: () => "",
    };

    return {
      timestamp,
      xCoordinate: xScale(timestamp),
      points: series.map((currentSeries) => ({
        id: currentSeries.id,
        label: currentSeries.label,
        color: currentSeries.color,
        value: currentSeries.values[closestPointIndex],
      })),
      tooltipAnchor: {
        getBoundingClientRect: () => tooltipRect,
      },
    };
  }

  const indicators = computeIndicators();

  return (
    <>
      <rect
        {...pointerContainerProps}
        x={drawingArea.left}
        y={drawingArea.top}
        width={drawingArea.width}
        height={drawingArea.height}
        fill="transparent"
      />
      {indicators != null && (
        <>
          <AxisHighlight
            x1={indicators.xCoordinate}
            y1={yScale.range()[0]}
            x2={indicators.xCoordinate}
            y2={yScale.range()[1]}
          />
          {indicators.points.map((point) =>
            point.value == null ? null : (
              <LineHighlightElement
                key={point.id}
                id={point.id}
                color={point.color}
                x={indicators.xCoordinate}
                y={yScale(point.value)}
              />
            ),
          )}
          <Popper
            open
            anchorEl={indicators.tooltipAnchor}
            placement="right-start"
            className={chartsTooltipClasses.root}
            sx={{
              pointerEvents: "none",
              zIndex: (theme) => theme.zIndex.modal,
            }}
          >
            <ChartsTooltipPaper className={chartsTooltipClasses.paper}>
              <ChartsTooltipTable className={chartsTooltipClasses.table}>
                <thead>
                  <ChartsTooltipRow>
                    <ChartsTooltipCell colSpan={3}>
                      <Typography>
                        {formatPlaybackTimestamp(indicators.timestamp)}
                      </Typography>
                    </ChartsTooltipCell>
                  </ChartsTooltipRow>
                </thead>
                <tbody>
                  {indicators.points.map((point) => (
                    <ChartsTooltipRow
                      key={point.id}
                      className={chartsTooltipClasses.row}
                    >
                      <ChartsTooltipCell
                        className={`${chartsTooltipClasses.cell} ${chartsTooltipClasses.markCell}`}
                      >
                        <ChartsTooltipMark
                          className={chartsTooltipClasses.mark}
                          color={point.color}
                        />
                      </ChartsTooltipCell>
                      <ChartsTooltipCell
                        className={`${chartsTooltipClasses.cell} ${chartsTooltipClasses.labelCell}`}
                      >
                        <Typography>{point.label}</Typography>
                      </ChartsTooltipCell>
                      <ChartsTooltipCell
                        className={`${chartsTooltipClasses.cell} ${chartsTooltipClasses.valueCell}`}
                      >
                        {point.value ?? "-"}
                      </ChartsTooltipCell>
                    </ChartsTooltipRow>
                  ))}
                </tbody>
              </ChartsTooltipTable>
            </ChartsTooltipPaper>
          </Popper>
        </>
      )}
    </>
  );
}

/**
 * Finds the index, into `timestamps`, of the timestamp nearest `value`. Only
 * timestamps within `domain` (bounds inclusive) are considered. When `value`
 * is equidistant from two timestamps, the earlier one wins.
 *
 * Returns `null` when no timestamp lies within `domain`.
 *
 * NOTE(review): the neighbour search assumes `timestamps` is sorted ascending
 * and, for the final `indexOf`, that its entries are unique; confirm both
 * hold for every caller.
 */
function findClosestPointIndex(
  timestamps: ReadonlyArray<number>,
  domain: { min: number; max: number },
  value: number,
): number | null {
  // The chart may include some records not in the visible domain. This is
  // intentional so part of the line doesn't disappear when a record reaches
  // either end of the domain, but those records outside the domain shouldn't
  // be considered when checking which record the pointer is closest to.
  const inDomainTimestamps = timestamps.filter(
    (timestamp) => domain.min <= timestamp && timestamp <= domain.max,
  );

  if (inDomainTimestamps.length === 0) {
    return null;
  }

  const mostRecentPointIndex = inDomainTimestamps.findLastIndex(
    (timestamp) => timestamp <= value,
  );

  let closestInDomainIndex: number;
  if (mostRecentPointIndex === -1) {
    // Value is before the first record, so the first record is the closest.
    closestInDomainIndex = 0;
  } else if (mostRecentPointIndex === inDomainTimestamps.length - 1) {
    // Value is at or after the last record, so the last record is the
    // closest. This branch and the previous one together cover the
    // single-record case.
    closestInDomainIndex = mostRecentPointIndex;
  } else {
    // Value lies between two records; pick whichever is nearer.
    const nextPointIndex = mostRecentPointIndex + 1;
    const mostRecentPointDistance =
      value - inDomainTimestamps[mostRecentPointIndex];
    const nextPointDistance = inDomainTimestamps[nextPointIndex] - value;

    closestInDomainIndex =
      mostRecentPointDistance <= nextPointDistance
        ? mostRecentPointIndex
        : nextPointIndex;
  }

  // The index above is into the in-domain subset; translate it back into an
  // index into the full list of timestamps.
  return timestamps.indexOf(inDomainTimestamps[closestInDomainIndex]);
}

/**
 * A source's records flattened into parallel arrays, ready to hand to MUI
 * charts. `timestamps[i]` and `values[i]` describe the same record.
 */
type PreparedSeries = {
  /** The source's id. */
  id: string;
  /** Display label: the field name, plus its unit in parentheses if known. */
  label: string;
  /** Color used to draw the series. */
  color: string;
  /** Record timestamps, in nanoseconds relative to a reference start time. */
  timestamps: Array<number>;
  /** The field's value for each record, or `null` where it isn't a number. */
  values: Array<number | null>;
};

/**
 * Converts a source's records into a `PreparedSeries`.
 *
 * Each record contributes one timestamp, made relative to `startTime` via
 * `utcToRelativeNanoseconds`. It also contributes one value: the record's
 * `field` (looked up with `get`) if that is a number, otherwise `null`, so the
 * timestamps and values arrays always have the same length.
 */
function prepareSeriesData(
  id: string,
  field: string,
  unit: string | null,
  color: string,
  records: ReadonlyArray<PlayerRecord<"default">>,
  startTime: bigint,
): PreparedSeries {
  const prepared: PreparedSeries = {
    id,
    label: getLabel(field, unit),
    color,
    timestamps: [],
    values: [],
  };

  for (const record of records) {
    prepared.timestamps.push(
      utcToRelativeNanoseconds(record.timestamp, startTime),
    );

    const fieldValue: unknown = get(record.data, field);
    prepared.values.push(typeof fieldValue === "number" ? fieldValue : null);
  }

  return prepared;
}

/**
 * Calculates evenly spaced x-axis tick positions for the overview chart.
 *
 * The playback duration is divided into `tickCount + 1` equal (floored,
 * whole-nanosecond) intervals, and a tick is placed at each interior
 * boundary. The result is exactly `tickCount` ticks; neither end of the range
 * gets one.
 *
 * @param playerBounds - Playback time range the overview chart spans.
 * @param tickCount - Number of interior ticks to produce.
 * @returns Tick positions in ascending order, in nanoseconds relative to
 *  `playerBounds.startTime`. The result is empty when `tickCount` isn't
 *  positive, or when the duration is too short to fit `tickCount` distinct
 *  non-zero integer ticks.
 */
function calculateOverviewTicks(
  playerBounds: TimeRange,
  tickCount: number,
): Array<number> {
  const durationNs = Number(playerBounds.endTime - playerBounds.startTime);
  const intervalNs = Math.floor(durationNs / (tickCount + 1));

  // A duration shorter than `tickCount + 1` ns floors the interval to 0.
  // Stepping by 0 would never terminate, so bail out. A negative or NaN
  // interval means there is no range to divide, which also yields no ticks.
  if (!(tickCount > 0 && intervalNs > 0)) {
    return [];
  }

  // Bound the ticks by count rather than by duration. Because `intervalNs` is
  // floored, stepping until the duration was reached emitted an extra tick a
  // few nanoseconds before the end whenever the duration didn't divide evenly.
  return Array.from(
    { length: tickCount },
    (_, index) => intervalNs * (index + 1),
  );
}

/**
 * Builds the display label for a plotted field. When a unit is known it is
 * appended in parentheses, e.g. `speed (m/s)`.
 */
function getLabel(field: string, unit: string | null): string {
  if (unit == null) {
    return field;
  }

  return `${field} (${unit})`;
}
