import {
  Card,
  CardBody,
  CardHeader,
  SimpleGrid,
  Skeleton,
  Stack,
  Stat,
  StatGroup,
  StatLabel,
  StatNumber,
  Text,
} from "@chakra-ui/react";
import { LineChart } from "@tremor/react";
import "chart.js/auto";
import moment from "moment";
import { useEffect } from "react";
import { getApi } from "../../services/api/api-service";
import { useGlobalAlert } from "../../services/global-alert/global-alert-context";
import {
  formatNumber,
  getCurrentTimezone,
  isHalfHourTimezone,
} from "../../utils";
import { ContactSupportLink } from "../contact-support-link";

/** Inputs accepted by the usage-stats dashboard. */
type Props = {
  /** Start of the reporting window, as a Unix timestamp in seconds. */
  startTimestamp: number;
  /** End of the reporting window, as a Unix timestamp in seconds. */
  endTimestamp: number;
  /** Size of each time bucket shown on the charts. */
  granularity: "HOUR" | "DAY";
  /** Forwarded to the usage-stats query; presumably restricts results to one model. */
  modelName?: string;
  /** Forwarded to the usage-stats query; presumably restricts results to one API key. */
  apiKeyName?: string;
};

// A single formatter instance is reused for every tick/tooltip value.
// "en-US" is the proper BCP 47 tag; the bare "us" tag is not a known locale
// and makes Intl silently fall back to the runtime's default locale.
const usNumberFormat = new Intl.NumberFormat("en-US");

/**
 * Formats a chart value with US-style thousands separators (e.g. 1,234,567).
 *
 * @param number - The numeric value to render.
 * @returns The formatted string.
 */
const valueFormatter = (number: number): string => usNumberFormat.format(number);

/**
 * Builds the ordered list of time-bucket labels covering `startDate` through
 * `endDate` (inclusive), one label per hour or per day.
 *
 * Labels are formatted in local time as `YYYY-MM-DD` for daily buckets and
 * `YYYY-MM-DDTHH:00:00` for hourly buckets (`HH:30:00` when the current
 * timezone is reported as a half-hour timezone).
 *
 * @param startDate - First instant of the range.
 * @param endDate - Last instant of the range; the bucket containing it is included.
 * @param granularity - Bucket size, `"HOUR"` or `"DAY"`.
 * @returns Bucket labels in chronological order; empty when `endDate` precedes
 *   the bucket containing `startDate`.
 */
function generateAllRequestTimeBetweenDates(
  startDate: Date,
  endDate: Date,
  granularity: "HOUR" | "DAY"
): string[] {
  const isHourly = granularity === "HOUR";
  // Deriving the step unit up front guarantees the loop below always advances.
  const unit = isHourly ? "hour" : "day";

  let format = "YYYY-MM-DD";
  if (isHourly) {
    format = isHalfHourTimezone(getCurrentTimezone())
      ? "YYYY-MM-DDTHH:30:00"
      : "YYYY-MM-DDTHH:00:00";
  }

  // Snap the cursor to the start of its bucket. The label format already
  // discards the sub-bucket part, so labels are unchanged, but without the
  // snap the bucket containing `endDate` is skipped whenever the end's
  // time-of-day (or minute-of-hour) is earlier than the start's.
  const cursor = moment(startDate).startOf(unit);
  const rangeEnd = moment(endDate);

  const labels: string[] = [];
  while (cursor.isSameOrBefore(rangeEnd)) {
    labels.push(cursor.format(format));
    cursor.add(1, unit);
  }

  return labels;
}

/**
 * Dashboard showing token and request usage for a time window.
 *
 * Loads `/usage-stats` for the given window and filters, raises a global error
 * alert if the request fails, and renders two cards ("Tokens" and "Requests"),
 * each with summary totals and a line chart with one point per time bucket.
 * While the request is loading, or no data is available, skeleton cards are
 * rendered instead.
 *
 * @param startTimestamp - Window start, Unix seconds.
 * @param endTimestamp - Window end, Unix seconds.
 * @param modelName - Optional value forwarded to the usage-stats query.
 * @param apiKeyName - Optional value forwarded to the usage-stats query.
 * @param granularity - Bucket size for the charts, `"HOUR"` or `"DAY"`.
 */
export default function StatsDashboard({
  startTimestamp,
  endTimestamp,
  modelName,
  apiKeyName,
  granularity,
}: Props) {
  const queryParams = {
    startTimestamp,
    endTimestamp,
    modelName,
    apiKeyName,
    granularity,
    timezone: getCurrentTimezone(),
  };
  const {
    data: queryResult,
    isLoading,
    error,
  } = getApi<UsageStatsResponse>("/usage-stats", queryParams);

  const { addAlert } = useGlobalAlert();
  useEffect(() => {
    if (error) {
      addAlert({
        type: "error",
        body: (
          <Text>
            Unexpected error happened when loading usage stats, please retry and{" "}
            <ContactSupportLink /> if the issue persists.
          </Text>
        ),
      });
    }
  }, [error, addAlert]);

  // NOTE(review): when the request fails and no data is returned, this branch
  // keeps showing skeletons alongside the alert — confirm that is intended.
  if (isLoading || !queryResult) {
    return (
      <SimpleGrid minChildWidth={"550px"} spacing={5}>
        <Card>
          <Skeleton height="450px" />
        </Card>
        <Card>
          <Skeleton height="450px" />
        </Card>
      </SimpleGrid>
    );
  }

  const labels = generateAllRequestTimeBetweenDates(
    new Date(startTimestamp * 1000),
    new Date(endTimestamp * 1000),
    granularity
  );

  // Aggregate all rows by request time in a single pass, so that building the
  // per-label series is a lookup instead of re-filtering every row per label.
  type BucketTotals = { input: number; output: number; requests: number };
  const totalsByRequestTime = new Map<string, BucketTotals>();
  for (const row of queryResult.rows) {
    let totals = totalsByRequestTime.get(row.requestTime);
    if (!totals) {
      totals = { input: 0, output: 0, requests: 0 };
      totalsByRequestTime.set(row.requestTime, totals);
    }
    totals.input += row.numInputTokens;
    totals.output += row.numOutputTokens;
    totals.requests += row.numRequests;
  }

  // One entry per label; labels without any matching row count as zero, and
  // rows whose request time matches no label are left out of the charts.
  const emptyBucket: BucketTotals = { input: 0, output: 0, requests: 0 };
  const buckets = labels.map((label) => {
    const totals = totalsByRequestTime.get(label) ?? emptyBucket;
    return {
      label,
      input: totals.input,
      output: totals.output,
      total: totals.input + totals.output,
      requests: totals.requests,
    };
  });

  // Window-wide totals for the stat tiles and maxima for sizing the y axes.
  let totalInputTokens = 0;
  let totalOutputTokens = 0;
  let totalRequests = 0;
  let maxTotalToken = 0;
  let maxNumRequest = 0;
  for (const bucket of buckets) {
    totalInputTokens += bucket.input;
    totalOutputTokens += bucket.output;
    totalRequests += bucket.requests;
    maxTotalToken = Math.max(maxTotalToken, bucket.total);
    maxNumRequest = Math.max(maxNumRequest, bucket.requests);
  }

  const numTokensChartData = buckets.map((bucket) => ({
    name: bucket.label,
    "Total Tokens": bucket.total,
    "Input Tokens": bucket.input,
    "Output Tokens": bucket.output,
  }));
  const numRequestsChartData = buckets.map((bucket) => ({
    name: bucket.label,
    Requests: bucket.requests,
  }));

  const chartColors = ["#596aff", "teal", "amber", "rose", "indigo", "emerald"];
  // Fixed width for small values, otherwise roughly 10px per digit of the largest value.
  const yAxisWidthFor = (max: number) =>
    max <= 100 ? 30 : max.toString().length * 10;

  return (
    <SimpleGrid minChildWidth={"550px"} spacing={5}>
      <Card variant="information">
        <CardHeader>Tokens</CardHeader>
        <CardBody>
          <StatGroup>
            <Stat>
              <StatLabel>Total</StatLabel>
              <StatNumber>
                {formatNumber(totalInputTokens + totalOutputTokens)}
              </StatNumber>
            </Stat>
            <Stat>
              <StatLabel>Input</StatLabel>
              <StatNumber>{formatNumber(totalInputTokens)}</StatNumber>
            </Stat>
            <Stat>
              <StatLabel>Output</StatLabel>
              <StatNumber>{formatNumber(totalOutputTokens)}</StatNumber>
            </Stat>
          </StatGroup>
          <LineChart
            data={numTokensChartData}
            index="name"
            categories={["Total Tokens", "Input Tokens", "Output Tokens"]}
            colors={chartColors}
            valueFormatter={valueFormatter}
            yAxisWidth={yAxisWidthFor(maxTotalToken)}
          />
        </CardBody>
      </Card>
      <Card variant="information">
        <CardHeader>Requests</CardHeader>
        <CardBody>
          <Stack spacing={3}>
            <StatGroup>
              <Stat>
                <StatLabel>Total</StatLabel>
                <StatNumber>{formatNumber(totalRequests)}</StatNumber>
              </Stat>
            </StatGroup>
            <LineChart
              data={numRequestsChartData}
              index="name"
              categories={["Requests"]}
              colors={chartColors}
              valueFormatter={valueFormatter}
              yAxisWidth={yAxisWidthFor(maxNumRequest)}
            />
          </Stack>
        </CardBody>
      </Card>
    </SimpleGrid>
  );
}
