import {
  Box,
  Button,
  Flex,
  Select,
  Skeleton,
  Stack,
  Table,
  Tbody,
  Td,
  Text,
  Th,
  Thead,
  Tr,
} from "@chakra-ui/react";
import {
  ColumnDef,
  ColumnSort,
  flexRender,
  getCoreRowModel,
  getFilteredRowModel,
  getPaginationRowModel,
  getSortedRowModel,
  Row,
  SortingState,
  useReactTable,
} from "@tanstack/react-table";
import * as React from "react";
import { FaArrowDownLong, FaArrowUpLong } from "react-icons/fa6";
import {
  MdKeyboardArrowLeft,
  MdKeyboardArrowRight,
  MdKeyboardDoubleArrowLeft,
  MdKeyboardDoubleArrowRight,
} from "react-icons/md";
import { formatNumber } from "../utils";

/**
 * Props accepted by `DataTable`.
 *
 * @typeParam Data - The shape of a single row in `data`.
 */
export type DataTableProps<Data extends object> = {
  /**
   * TanStack column definitions. The value type is `any` because one column
   * array normally mixes columns whose accessors return different types.
   */
  columns: ColumnDef<Data, any>[];
  /** Rows to show. While `undefined`, the loading skeleton is rendered. */
  data: Data[] | undefined;
  /** When true, the loading skeleton is rendered instead of the table. */
  isLoading: boolean;

  /** Sorting state the table starts with; defaults to unsorted. */
  defaultSorting?: ColumnSort[];
  /**
   * Rendered below the table when `data` is an empty array; the pagination
   * footer is hidden in that case.
   */
  emptyView?: React.ReactElement;
  /**
   * Rendered in a full-width row directly beneath every row whose
   * `getIsExpanded()` is true.
   */
  expandedRowRenderer?: (row: Row<Data>) => React.ReactElement;
  /**
   * Invoked when a cell is clicked, except for cells in columns whose
   * `meta.isActionColumn` is truthy.
   */
  onRowClick?: (row: Row<Data>) => void;
  /**
   * Total shown in the footer's "of N results" text. When omitted,
   * `data.length` is used.
   */
  totalRows?: number;
};

/**
 * Presentation flags that `DataTable` reads from a column's `columnDef.meta`.
 * Any other keys present on `meta` are ignored by this component.
 */
type DataTableColumnMeta = {
  /** Forwarded to Chakra's `isNumeric` on the header and body cells. */
  isNumeric?: boolean;
  /** When true, clicking a cell in this column does not call `onRowClick`. */
  isActionColumn?: boolean;
};

/**
 * Sortable, client-side paginated table built on TanStack Table and Chakra UI.
 *
 * Behavior by state:
 * - `isLoading` is true or `data` is undefined: renders a skeleton placeholder.
 * - Otherwise: renders the header row, where clicking a sortable column
 *   toggles its sort. It then renders the current page of rows. Under each row
 *   whose `getIsExpanded()` is true, it adds a full-width row produced by
 *   `expandedRowRenderer`.
 * - After the table: renders `emptyView` when `data` is empty, and the
 *   pagination footer otherwise.
 *
 * Pagination starts at 50 rows per page.
 */
export function DataTable<Data extends object>({
  data,
  isLoading,
  columns,
  defaultSorting,
  emptyView,
  onRowClick,
  expandedRowRenderer,
  totalRows,
}: DataTableProps<Data>) {
  const [sorting, setSorting] = React.useState<SortingState>(
    defaultSorting ?? []
  );

  // Hooks must run on every render, so the table instance is created before
  // the loading early-return below.
  const table = useReactTable({
    columns,
    data: data ?? [],
    getCoreRowModel: getCoreRowModel(),
    onSortingChange: setSorting,
    getSortedRowModel: getSortedRowModel(),
    getFilteredRowModel: getFilteredRowModel(),
    getPaginationRowModel: getPaginationRowModel(),
    state: {
      sorting,
    },
    initialState: {
      pagination: {
        pageSize: 50,
      },
    },
  });

  if (isLoading || !data) {
    return (
      <Stack>
        <Skeleton height="40px" />
        {[...Array(10)].map((_, idx) => (
          <Skeleton height="50px" key={idx} />
        ))}
      </Stack>
    );
  }

  const { pageIndex, pageSize } = table.getState().pagination;
  // Use the table's own page count (the same value the "last page" button
  // relies on) instead of recomputing it, so the page selector and the
  // navigation buttons can never disagree.
  const pageCount = table.getPageCount();
  const firstShownRow = pageIndex * pageSize + 1;
  const lastShownRow = pageIndex * pageSize + table.getRowModel().rows.length;

  return (
    <>
      <Table>
        <Thead bgColor="highlightBg">
          {table.getHeaderGroups().map((headerGroup) => (
            <Tr key={headerGroup.id}>
              {headerGroup.headers.map((header) => {
                const meta = header.column.columnDef.meta as
                  | DataTableColumnMeta
                  | undefined;
                const sortDirection = header.column.getIsSorted();
                return (
                  <Th
                    key={header.id}
                    colSpan={header.colSpan}
                    onClick={header.column.getToggleSortingHandler()}
                    isNumeric={meta?.isNumeric}
                  >
                    <Flex
                      align="center"
                      justifyContent="left"
                      // Only advertise clickability when the column can sort.
                      cursor={header.column.getCanSort() ? "pointer" : "default"}
                    >
                      {/* Placeholder headers (grouped columns) have no content. */}
                      {header.isPlaceholder
                        ? null
                        : flexRender(
                            header.column.columnDef.header,
                            header.getContext()
                          )}
                      <Box ml={2}>
                        {sortDirection === "desc" ? (
                          <FaArrowDownLong aria-label="sorted descending" />
                        ) : sortDirection === "asc" ? (
                          <FaArrowUpLong aria-label="sorted ascending" />
                        ) : null}
                      </Box>
                    </Flex>
                  </Th>
                );
              })}
            </Tr>
          ))}
        </Thead>
        <Tbody>
          {table.getRowModel().rows.map((row) => {
            const isExpanded = row.getIsExpanded();
            return (
              // The key must sit on the outermost element returned from the
              // map. A keyless `<>` fragment triggers React key warnings and
              // breaks reconciliation when rows are re-sorted or re-paged.
              <React.Fragment key={row.id}>
                <Tr
                  fontSize="sm"
                  cursor={onRowClick ? "pointer" : "auto"}
                  _hover={{
                    bgColor: onRowClick ? "gray.100" : "white",
                    cursor: onRowClick ? "pointer" : "auto",
                  }}
                >
                  {row.getVisibleCells().map((cell) => {
                    const meta = cell.column.columnDef.meta as
                      | DataTableColumnMeta
                      | undefined;
                    return (
                      <Td
                        key={cell.id}
                        isNumeric={meta?.isNumeric}
                        onClick={() => {
                          // Action columns (e.g. buttons/menus) handle their own
                          // clicks and must not also trigger the row handler.
                          if (onRowClick && !meta?.isActionColumn) {
                            onRowClick(row);
                          }
                        }}
                        // Drop the divider so the expanded row reads as part
                        // of this row.
                        borderBottomWidth={isExpanded ? 0 : 1}
                      >
                        {flexRender(
                          cell.column.columnDef.cell,
                          cell.getContext()
                        )}
                      </Td>
                    );
                  })}
                </Tr>
                {expandedRowRenderer && isExpanded && (
                  <Tr borderBottomWidth={1}>
                    <Td colSpan={row.getVisibleCells().length}>
                      {expandedRowRenderer(row)}
                    </Td>
                  </Tr>
                )}
              </React.Fragment>
            );
          })}
        </Tbody>
      </Table>
      {data.length === 0 && emptyView}
      {data.length > 0 && (
        <Flex justify="space-between" textColor="gray.600" fontSize="sm" pt={5}>
          <Flex align="center">
            Showing {firstShownRow} to {lastShownRow} of{" "}
            {formatNumber(totalRows ?? data.length)} results
          </Flex>
          <Flex>
            <Button
              size="sm"
              variant="outline"
              aria-label="First page"
              onClick={() => table.setPageIndex(0)}
              isDisabled={!table.getCanPreviousPage()}
            >
              <MdKeyboardDoubleArrowLeft />
            </Button>
            <Button
              size="sm"
              variant="outline"
              aria-label="Previous page"
              onClick={() => table.previousPage()}
              isDisabled={!table.getCanPreviousPage()}
            >
              <MdKeyboardArrowLeft />
            </Button>
            <Flex align="center" pl={1} pr={1} className="gap-1">
              <Text>Page: </Text>
              <Select
                w="auto"
                size="sm"
                borderRadius={5}
                value={pageIndex}
                onChange={(e) => {
                  table.setPageIndex(Number(e.target.value));
                }}
              >
                {Array.from({ length: pageCount }, (_, idx) => (
                  <option key={idx} value={idx}>
                    {idx + 1}
                  </option>
                ))}
              </Select>
            </Flex>
            <Button
              size="sm"
              variant="outline"
              aria-label="Next page"
              onClick={() => table.nextPage()}
              isDisabled={!table.getCanNextPage()}
            >
              <MdKeyboardArrowRight />
            </Button>
            <Button
              size="sm"
              variant="outline"
              aria-label="Last page"
              onClick={() => table.setPageIndex(pageCount - 1)}
              isDisabled={!table.getCanNextPage()}
            >
              <MdKeyboardDoubleArrowRight />
            </Button>
          </Flex>
          <Flex align="center" className="gap-1">
            <Text>Page size:</Text>
            <Select
              w="auto"
              size="sm"
              borderRadius={5}
              value={pageSize}
              onChange={(e) => {
                table.setPageSize(Number(e.target.value));
              }}
            >
              {[10, 20, 30, 50, "all"].map((option) =>
                // "all" is a page size equal to the current row count.
                option === "all" ? (
                  <option key="all" value={data.length}>
                    all
                  </option>
                ) : (
                  <option key={option} value={option}>
                    {option}
                  </option>
                )
              )}
            </Select>
          </Flex>
        </Flex>
      )}
    </>
  );
}
