// eslint-enable no-implicit-coercion
"use client";
import * as amplitude from "@amplitude/analytics-browser";
import {
  Button,
  Card,
  Checkbox,
  Divider,
  Drawer,
  DrawerBody,
  DrawerContent,
  DrawerHeader,
  DrawerOverlay,
  Flex,
  Grid,
  GridItem,
  Stack,
  Text,
  Textarea,
  useDisclosure,
} from "@chakra-ui/react";
import { useAuth } from "@clerk/clerk-react";
import { HamburgerMenuIcon } from "@radix-ui/react-icons";
import { SearchSelect, SearchSelectItem } from "@tremor/react";
import OpenAI from "openai";
import {
  ChatCompletionContentPart,
  ChatCompletionContentPartImage,
} from "openai/resources/index.mjs";
import { useEffect, useReducer, useRef, useState } from "react";
import { useSearchParams } from "react-router-dom";
import { v4 as uuidv4 } from "uuid";
import { ChatInput, ChatMessages } from ".";
import { BaseModelResponse, LoRAResponse } from "../../interfaces/model";
import DashboardContentLayout from "../../layouts/dashboard-content-layout";
import { useApiClient } from "../../services/api/api-client-context";
import ApiClient, { getApi } from "../../services/api/api-service";
import { useGlobalAlert } from "../../services/global-alert/global-alert-context";
import { getBaseApiUrl } from "../../utils";
import { ContactSupportLink } from "../contact-support-link";
import LoadingIndicator from "../loading-indicator";
import { ChatMessage, ChatState } from "./common/types";
import SliderInput from "./ui/slider-input";

/**
 * Reducer action that overwrites a single top-level field of `ChatState`.
 * `value` is typed by the field being replaced.
 */
interface ChatAction<Type extends keyof ChatState> {
  field: Type;
  value: ChatState[Type];
}

/** Result shape this component reads from `getApi` (the call sites cast to it). */
type ApiResult<T> = {
  data: T;
  error: Error | null;
};

/** Model offered to signed-out visitors (served through the `/v1/demo` endpoint). */
const DEMO_MODEL_NAME = "OpenGVLab/InternVL-Chat-V1-5";

/** Model pre-selected for signed-in users when no `?model=` param is present. */
const DEFAULT_MODEL_NAME = "mistralai/Mixtral-8x7B-Instruct-v0.1";

// Module-level so signed-out renders see referentially stable arrays; inline
// literals would be new objects every render and re-fire every effect that
// depends on `lorasResult.data` / `baseModelsResult.data`.
const DEMO_LORAS: LoRAResponse[] = [];
const DEMO_BASE_MODELS: BaseModelResponse[] = [
  {
    name: DEMO_MODEL_NAME,
    isVlmModel: true,
    unitPrice: 0,
    toolsEnabled: false,
    activeDeployment: null,
  },
];

/**
 * Sets `metadata.loading = false` on the last message of `messages`, if that
 * message has metadata. Mutates the message object in place.
 */
function markLastMessageLoaded(messages: ChatMessage[]): void {
  const last = messages[messages.length - 1];
  if (last?.metadata != null) {
    last.metadata.loading = false;
  }
}

/**
 * Interactive chat playground.
 *
 * Signed-in users pick any of their base models or LoRAs and talk to
 * `${getBaseApiUrl()}/v1`; signed-out visitors only get the demo model via
 * `${getBaseApiUrl()}/v1/demo`. The selected model is mirrored into the
 * `?model=` query parameter, and responses are streamed into the last
 * assistant message as they arrive.
 */
export function PlaygroundModule() {
  const [searchParams, setSearchParams] = useSearchParams();
  const preSelectedModelName = searchParams.get("model");

  const [isLoading, setIsLoading] = useState(false);
  const { addAlert, clear } = useGlobalAlert();
  const [isSessionEnded, setIsSessionEnded] = useState(false);
  // Bumped whenever streamed content is written into an existing message
  // object; passed to ChatMessages, presumably so it re-renders on in-place
  // mutations.
  const [trigger, setTrigger] = useState(0);
  const { userId } = useAuth();
  let apiClient: ApiClient | undefined;

  let lorasResult: ApiResult<LoRAResponse[]>;
  let baseModelsResult: ApiResult<BaseModelResponse[]>;
  // NOTE(review): `useApiClient` (and `getApi`, if it wraps a data-fetching
  // hook) is only called when `userId` is set, so the hook count changes if
  // the auth state flips while this component is mounted — a Rules of Hooks
  // violation. Consider keying this component on `userId` at the call site
  // or making these calls unconditional.
  if (userId) {
    // eslint-disable-next-line react-hooks/rules-of-hooks
    apiClient = useApiClient();

    lorasResult = getApi<LoRAResponse[]>("/lora") as ApiResult<
      LoRAResponse[]
    >;
    baseModelsResult = getApi<BaseModelResponse[]>("/base-model") as ApiResult<
      BaseModelResponse[]
    >;
  } else {
    lorasResult = { data: DEMO_LORAS, error: null };
    baseModelsResult = { data: DEMO_BASE_MODELS, error: null };
  }

  const [selectedModel, setSelectedModel] = useState<
    LoRAResponse | BaseModelResponse
  >();
  const [temperature, setTemperature] = useState(0.7);
  const [maxTokens, setMaxTokens] = useState(256);
  const [jsonMode, setJsonMode] = useState(false);
  const [topP, setTopP] = useState(1.0);
  const [systemPrompt, setSystemPrompt] = useState("");
  const { isOpen, onOpen, onClose } = useDisclosure();

  const [requestBody, setRequestBody] = useReducer(
    (state: ChatState, action: ChatAction<keyof ChatState>): ChatState => {
      if (action.field === "stop") {
        // Drop blank stop sequences.
        return {
          ...state,
          [action.field]: (action.value as string[]).filter(
            (v: string) => v.trim().length > 0
          ),
        };
      }
      return { ...state, [action.field]: action.value };
    },
    {
      messages: [],
      stop: [],
      top_p: 1,
      top_k: 50,
      presence_penalty: 0,
      frequency_penalty: 0,
      context_length_exceeded_behavior: "truncate",
      // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
      // ...model.generationDefaults!,
      temperature: 0,
      max_tokens: 1024,
    }
  );
  const endOfMessagesRef = useRef<HTMLDivElement | null>(null);

  // Select an initial model once the model lists are available: the
  // `?model=` param if present, otherwise a per-auth-state default.
  useEffect(() => {
    if (!lorasResult.data || !baseModelsResult.data || selectedModel) {
      return;
    }
    updateSelectedModel(
      preSelectedModelName ?? (userId ? DEFAULT_MODEL_NAME : DEMO_MODEL_NAME)
    );
  }, [
    selectedModel,
    lorasResult.data,
    baseModelsResult.data,
    preSelectedModelName,
    userId,
  ]);

  // Switching models starts a fresh conversation.
  useEffect(() => {
    setRequestBody({
      field: "messages",
      value: [],
    });
    setIsSessionEnded(false);
  }, [selectedModel]);

  const scrollToBottom = () => {
    endOfMessagesRef.current?.scrollIntoView({ behavior: "smooth" });
  };

  useEffect(() => {
    // Scroll after the browser has rendered the updated request body
    // (including its messages); the zero-delay timeout defers to that render.
    const timeoutId = setTimeout(scrollToBottom, 0);
    return () => clearTimeout(timeoutId);
  }, [requestBody]);

  useEffect(() => {
    if (lorasResult.error || baseModelsResult.error) {
      addAlert({
        type: "error",
        body: (
          <Text>
            Unexpected error happened when loading models, please retry{" "}
            <ContactSupportLink /> if the issue persists.
          </Text>
        ),
      });
    }
  }, [lorasResult.error, baseModelsResult.error, addAlert]);

  if (!lorasResult.data || !baseModelsResult.data) {
    return <LoadingIndicator />;
  }

  /**
   * Merges the tool-call deltas of one streamed chunk into the last message
   * of `updatedMessages` and publishes the updated list. A delta whose
   * `index` is past the known tool calls starts a new call; otherwise its
   * argument fragment is appended to the existing call at that index.
   */
  const streamFunctionCall = (
    updatedMessages: ChatMessage[],
    chunk: OpenAI.Chat.Completions.ChatCompletionChunk
  ) => {
    const toolCallDeltas = chunk.choices[0]?.delta?.tool_calls;
    if (!toolCallDeltas) {
      return;
    }

    for (const delta of toolCallDeltas) {
      if (!delta) {
        continue;
      }
      const lastMessage = updatedMessages[updatedMessages.length - 1];
      const toolCalls = lastMessage.toolCalls || [];
      if (delta.index >= toolCalls.length) {
        markLastMessageLoaded(updatedMessages);
        toolCalls.push({
          id: delta.id ?? "",
          type: delta.type ?? "function",
          function: {
            name: delta.function?.name ?? "",
            arguments: delta.function?.arguments ?? "",
          },
        });
      } else {
        const existing = toolCalls[delta.index];
        toolCalls[delta.index] = {
          ...existing,
          function: {
            ...existing.function,
            // Append only the new fragment; a delta with no arguments must
            // not concatenate the text "undefined".
            arguments:
              existing.function.arguments + (delta.function?.arguments ?? ""),
          },
        };
      }
      lastMessage.toolCalls = toolCalls;
      setRequestBody({ field: "messages", value: [...updatedMessages] });
      setTrigger((prev) => prev + 1);
    }
  };

  /**
   * Converts UI chat messages into OpenAI message params. A message may
   * expand into several params: a `tool` response (from
   * `metadata.functionResponse`), an `assistant` tool-call message, and/or a
   * plain content message.
   */
  const serializeMessages = (
    messages: ChatMessage[]
  ): OpenAI.Chat.Completions.ChatCompletionMessageParam[] => {
    // Fallback id for tool responses whose own message carries no tool call:
    // the id of the most recent tool call earlier in the conversation.
    let lastToolCallId = "";
    return messages.flatMap((message) => {
      const serialized: OpenAI.Chat.Completions.ChatCompletionMessageParam[] =
        [];
      const toolCalls = message.toolCalls ?? [];
      const firstToolCall = toolCalls[0];
      const hasToolCalls = toolCalls.length > 0;

      if (message.metadata?.functionResponse) {
        serialized.push({
          role: "tool",
          content: message.metadata.functionResponse,
          tool_call_id:
            firstToolCall?.id ||
            firstToolCall?.function.name ||
            lastToolCallId,
        });
      }
      if (hasToolCalls) {
        lastToolCallId = firstToolCall.id;
        serialized.push({
          role: "assistant",
          tool_calls: toolCalls.map(
            (toolCall) =>
              ({
                id: toolCall.id,
                type: toolCall.type,
                function: {
                  name: toolCall.function.name,
                  arguments: toolCall.function.arguments,
                },
              }) as OpenAI.Chat.Completions.ChatCompletionMessageToolCall
          ),
        });
      }
      if (message.content && !hasToolCalls) {
        serialized.push({
          role: message.role,
          content: message.content,
          tool_call_id: message.toolCallId,
        } as OpenAI.Chat.Completions.ChatCompletionMessageParam);
      }
      return serialized;
    });
  };

  /**
   * Appends `newMessage` (or, for a `tool` message, an assistant placeholder
   * carrying its content as `functionResponse`) plus a loading assistant
   * placeholder, then streams the model's reply into that placeholder.
   * Errors are surfaced through the global alert. `isNested` skips toggling
   * the global loading state and clearing existing alerts.
   */
  // eslint-disable-next-line complexity
  const fetchChatCompletion = async (
    newMessage: ChatMessage,
    isNested: boolean
  ) => {
    if (!selectedModel) {
      return;
    }
    const activeModel = selectedModel;
    // Work on a copy so the reducer's current state array is never mutated.
    const updatedMessages = [...requestBody.messages];

    if (!isNested) {
      clear();
      setIsLoading(true);
    }
    try {
      const openai = new OpenAI({
        baseURL: `${getBaseApiUrl()}/v1` + (userId ? "" : "/demo"),
        apiKey: (await apiClient?.authTokenGetter()) ?? "",
        dangerouslyAllowBrowser: true,
      });

      if (newMessage.role !== "tool") {
        updatedMessages.push(newMessage);
        updatedMessages.push({
          id: uuidv4(),
          content: "",
          role: "assistant",
          metadata: {
            loading: true,
          },
        });
      } else {
        updatedMessages.push({
          id: uuidv4(),
          content: "",
          role: "assistant",
          metadata: {
            loading: true,
            functionResponse: newMessage.content,
          },
        });
      }
      setRequestBody({
        field: "messages",
        value: [...updatedMessages],
      });

      const stream = await openai.chat.completions.create({
        model: activeModel.name,
        stream: true,
        messages: serializeMessages([
          {
            id: uuidv4(),
            content: systemPrompt,
            role: "system",
          },
          ...updatedMessages,
        ]),
        temperature,
        top_p: topP,
        max_tokens: maxTokens,
        // The JSON-mode checkbox is hidden for VLM models, so a value left
        // over from a previously selected model must not apply to them.
        response_format:
          jsonMode && !activeModel.isVlmModel
            ? { type: "json_object" }
            : { type: "text" },
      });

      for await (const chunk of stream) {
        // Some chunks can carry an empty `choices` array (e.g. a trailing
        // usage chunk); there is nothing to render for them.
        const delta = chunk.choices[0]?.delta;
        if (!delta) {
          continue;
        }
        if (delta.content) {
          markLastMessageLoaded(updatedMessages);
          updatedMessages[updatedMessages.length - 1].content += delta.content;
          setTrigger((prev) => prev + 1);
          setRequestBody({ field: "messages", value: [...updatedMessages] });
        } else if (delta.tool_calls) {
          streamFunctionCall(updatedMessages, chunk);
        }
      }
    } catch (e) {
      addAlert({
        type: "error",
        body: (
          <Text>
            <b>Error happened when getting response: </b> {String(e)}
          </Text>
        ),
      });
    } finally {
      // Whether the stream completed, failed, or produced no output, the
      // placeholder is no longer waiting for content.
      markLastMessageLoaded(updatedMessages);
      setRequestBody({ field: "messages", value: [...updatedMessages] });
      if (!isNested) {
        setIsLoading(false);
      }
    }
  };

  /**
   * Selects the model named `modelName` (LoRAs take precedence over base
   * models with the same name) and mirrors it into the `?model=` param.
   * Unknown names are ignored.
   */
  function updateSelectedModel(modelName: string) {
    if (!lorasResult.data || !baseModelsResult.data) {
      return;
    }

    const model: LoRAResponse | BaseModelResponse | undefined =
      lorasResult.data.find((m) => m.name === modelName) ??
      baseModelsResult.data.find((m) => m.name === modelName);

    if (model) {
      setSelectedModel(model);
      setSearchParams({ model: model.name });
    }
  }

  /** Model picker and sampling parameters; rendered in the sidebar and the mobile drawer. */
  function renderModelParamCard() {
    return (
      <Card h="100%" p={5} variant="information">
        <Stack gap={3}>
          <Stack gap={3}>
            <Text>
              <b>Model</b>
            </Text>
            <SearchSelect
              className="max-content"
              placeholder="Select a model"
              onValueChange={updateSelectedModel}
              value={selectedModel?.name}
            >
              {baseModelsResult.data.map((data) => (
                <SearchSelectItem
                  title={data.name}
                  key={`base-model-${data.name}`}
                  value={data.name}
                  className="cursor-pointer"
                >
                  Base Model: {data.name}
                </SearchSelectItem>
              ))}
              {lorasResult.data.map((data) => (
                <SearchSelectItem
                  title={data.name}
                  key={`lora-${data.name}`}
                  value={data.name}
                  className="cursor-pointer"
                >
                  LoRA: {data.name}
                </SearchSelectItem>
              ))}
            </SearchSelect>
          </Stack>
          <Divider />
          <Stack>
            <Text>
              <b>Parameters</b>
            </Text>
            {!selectedModel?.isVlmModel && (
              <Checkbox
                isChecked={jsonMode}
                onChange={() => setJsonMode((prev) => !prev)}
              >
                JSON mode
              </Checkbox>
            )}
            <SliderInput
              label="Temperature"
              max={2.0}
              min={0.0}
              step={0.1}
              value={temperature}
              setValue={setTemperature}
            />
            <SliderInput
              label="Max output tokens"
              max={1024}
              min={0}
              step={1}
              value={maxTokens}
              setValue={setMaxTokens}
            />
            <SliderInput
              label="Top P"
              max={1.0}
              min={0.0}
              step={0.1}
              value={topP}
              setValue={setTopP}
            />
            <>
              <Text>System prompt</Text>
              <Textarea
                value={systemPrompt}
                onChange={(e) => setSystemPrompt(e.target.value)}
                placeholder="Enter custom System prompt here..."
                size="sm"
                h="150px"
              />
            </>
          </Stack>
        </Stack>
      </Card>
    );
  }

  return (
    <DashboardContentLayout
      subNavSection="Overview"
      mainTitle="Playground"
      mainTitleHelperText={
        <>
          <Text color="gray.600" fontSize={{ base: "md", md: "sm" }} pt={2}>
            Run quick test and play around with base models and LoRAs.
          </Text>
        </>
      }
    >
      <Grid gap={2} templateColumns="repeat(4, 1fr)">
        <GridItem colSpan={{ base: 4, md: 3 }}>
          <Card
            p={5}
            variant="information"
            height={{ base: "85vh", md: "75vh" }}
          >
            {/* Mobile-only button that opens the parameter drawer. */}
            <Flex
              justifyContent="end"
              w="full"
              pb={4}
              display={{ base: "flex", md: "none" }}
            >
              <Button onClick={onOpen}>
                <HamburgerMenuIcon />
              </Button>
            </Flex>
            <ChatMessages
              messages={requestBody.messages}
              isLoading={isLoading}
              trigger={trigger}
            />
            <Stack>
              <ChatInput
                selectedModelName={selectedModel?.name ?? ""}
                isVlmModel={selectedModel?.isVlmModel ?? false}
                onSubmit={async (text, images?: string[]) => {
                  amplitude.track("Send Chat Message");
                  // VLM models take multi-part content (text + image URLs);
                  // everything else gets the plain text.
                  let content: string | ChatCompletionContentPart[] = text;
                  if (selectedModel?.isVlmModel) {
                    content = [
                      {
                        type: "text",
                        text: text,
                      },
                      ...(images ?? []).map(
                        (image) =>
                          ({
                            type: "image_url",
                            image_url: {
                              url: image,
                            },
                          }) as ChatCompletionContentPartImage
                      ),
                    ];
                  }
                  await fetchChatCompletion(
                    {
                      id: uuidv4(),
                      // @ts-expect-error hack to make the type work
                      content: content,
                      role: "user",
                    },
                    false
                  );
                }}
                multiModal={false}
                isLoading={isLoading || !selectedModel}
                isSessionEnded={isSessionEnded}
                onReset={() => {
                  setRequestBody({
                    field: "messages",
                    value: [],
                  });
                  setIsSessionEnded(false);
                }}
              />
            </Stack>
          </Card>
        </GridItem>
        <GridItem colSpan={1} display={{ base: "none", md: "block" }}>
          {renderModelParamCard()}
        </GridItem>
      </Grid>
      <Drawer placement="right" onClose={onClose} isOpen={isOpen}>
        <DrawerOverlay />
        <DrawerContent>
          <DrawerHeader borderBottomWidth="1px">Model settings</DrawerHeader>
          <DrawerBody>{renderModelParamCard()}</DrawerBody>
        </DrawerContent>
      </Drawer>
    </DashboardContentLayout>
  );
}
