import React, { CSSProperties, useMemo, useState } from "react";
import {
  Box,
  Paper,
  Table,
  TableBody,
  TableCell,
  TableContainer,
  TableFooter,
  TableHead,
  TableRow,
  TableSortLabel,
} from "@mui/material";
import "./DataTable.scss";

/**
 * How a column is sorted: `true` sorts by the value the column's `render`
 * returns; a function returns the key to sort by.
 */
type SortType<E extends object> = ((obj: E) => number | string) | true;
/** Sort order of the active column. */
type SortDirection = "desc" | "asc";
/** Definition of a single table column. */
interface ColumnType<E extends object> {
  /** Header text. */
  title: string;
  /**
   * Column identifier: used as the React key for the column's cells, matched
   * against the sort state / `defaultSort.property`, and passed to
   * `onClickCell`.
   */
  property: string;
  /**
   * Renders the cell content for `obj`. When rendering cells, `index` is the
   * row's position among the displayed (sorted) rows.
   */
  render: (obj: E, index: number) => React.ReactNode;
  /**
   * Makes the column sortable. `true` sorts by the value returned from
   * `render`; a function returns the sort key for a row.
   */
  sortFunc?: SortType<E>;
  /**
   * Direction used when sorting switches to this column from a different one.
   * Falls back to `defaultSort.direction`, then to "asc".
   */
  defaultSortDirection?: SortDirection;
  /** Set to `false` to hide the column; any other value shows it. */
  visible?: boolean;
}
/** Per-row presentation options, as returned by `rowDisplay`. */
interface RowData {
  /** Extra class names added after the row's even/odd class. */
  className?: string;
  /** Inline styles applied to the row element. */
  style?: CSSProperties;
  /** Forwarded to the row's `selected` prop. */
  selected?: boolean;
}
/** Props for `DataTable`. */
interface DataTableProps<E extends object> {
  /** Rows to display. */
  data: Array<E>;
  /** Column definitions; columns with `visible: false` are not rendered. */
  columns: Array<ColumnType<E>>;
  /**
   * Stable key per row, used as the React row key and to match
   * `keysAtBottom`. Without it the row index is used as the key and
   * `keysAtBottom` has no effect.
   */
  getKey?: (obj: E) => string | number;
  /** Keys of rows kept at the end of the table while a sort is active. */
  keysAtBottom?: Array<string | number>;
  /**
   * Initial sort column and direction. `direction` is also the fallback when
   * switching to a column without its own `defaultSortDirection`.
   */
  defaultSort?: {
    property: string;
    direction?: SortDirection;
  };
  /** Per-row presentation options (class, style, selection). */
  rowDisplay?: (obj: E, index: number) => RowData;
  /** Called when a body cell is clicked, with the row, its displayed index and the column property. */
  onClickCell?: (obj: E, index: number, column: string) => void;
  /** Styles merged over the container's default styles. */
  style?: CSSProperties;
  /** Content rendered in a single full-width footer cell when truthy. */
  footer?: React.ReactNode;
  /** Rendered instead of the table when there are no rows. */
  fallbackContent?: React.ReactElement;
}

/**
 * Sortable table with a sticky header, built on MUI `Table`.
 *
 * Only columns whose `visible` is not `false` are shown. Clicking the header
 * of a column that has a `sortFunc` sorts by it; clicking the active column
 * again toggles the direction. While a sort is active, rows whose `getKey`
 * value is listed in `keysAtBottom` stay at the end of the table, in their
 * input order, regardless of direction.
 *
 * If there are no rows and `fallbackContent` is provided, that element is
 * rendered instead of the table.
 *
 * Pass stable (memoised) `columns`, `getKey` and `keysAtBottom` to avoid
 * re-sorting on every render.
 */
export function DataTable<E extends object>(props: DataTableProps<E>) {
  const {
    columns: allColumns,
    data,
    getKey,
    keysAtBottom,
    defaultSort,
    onClickCell,
  } = props;

  const [sortProperty, setSortProperty] = useState<string | null>(
    defaultSort?.property || null
  );
  const [sortDirection, setSortDirection] = useState<SortDirection>(
    defaultSort?.direction || "asc"
  );

  // Memoised so the memos below only recompute when the column definitions
  // actually change, instead of on every render.
  const columns = useMemo(
    () => allColumns.filter((col) => col.visible !== false),
    [allColumns]
  );

  // Index (into `columns`) of the column being sorted by, or null when no
  // sort applies. A matching column without `sortFunc` (e.g. `defaultSort`
  // naming a non-sortable column) counts as "no sort" instead of crashing
  // later when its missing sort function would be called.
  const sortingProperty = useMemo(() => {
    if (sortProperty === null) return null;
    const index = columns.findIndex(
      (col) => col.property === sortProperty && !!col.sortFunc
    );
    return index < 0 ? null : index;
  }, [columns, sortProperty]);

  const resultData = useMemo(() => {
    const sortingColumn =
      sortingProperty === null ? undefined : columns[sortingProperty];
    const sortFunc = sortingColumn?.sortFunc;
    if (sortingColumn === undefined || sortFunc === undefined) {
      return [...data];
    }

    // Split off pinned rows in one pass; both groups keep their input order.
    const bottomKeys = new Set(keysAtBottom ?? []);
    const sortable: E[] = [];
    const rowsAtBottom: E[] = [];
    for (const item of data) {
      if (getKey && bottomKeys.has(getKey(item))) rowsAtBottom.push(item);
      else sortable.push(item);
    }

    // `true` means "sort by the rendered cell value". Sort callbacks receive
    // index 0 because the row's display position is not known yet.
    const readSortValue: (obj: E) => React.ReactNode =
      sortFunc === true
        ? (obj) => sortingColumn.render(obj, 0)
        : (obj) => sortFunc(obj, 0);

    // Compute each key once instead of on every comparison. Empty values
    // (null, undefined, false, "") sort as "", but 0 is kept as-is.
    const keyed = sortable.map((item) => {
      const value = readSortValue(item);
      return { item, sortKey: value || value === 0 ? value : "" };
    });
    keyed.sort((a, b) => {
      // Two one-sided checks keep the comparator antisymmetric even for
      // mixed string/number keys, NaN, or non-primitive rendered values.
      if (a.sortKey < b.sortKey) return -1;
      if (b.sortKey < a.sortKey) return 1;
      return 0;
    });

    const sorted = keyed.map(({ item }) => item);
    if (sortDirection === "desc") sorted.reverse();
    return [...sorted, ...rowsAtBottom];
  }, [data, columns, sortingProperty, sortDirection, getKey, keysAtBottom]);

  if (resultData.length === 0 && props.fallbackContent)
    return props.fallbackContent;

  return (
    <TableContainer
      component={Paper}
      style={{
        maxHeight: "max(100vh - 120px, 400px)",
        ...props.style,
      }}
    >
      <Table size="small" stickyHeader>
        <TableHead>
          <TableRow>
            {columns.map((col, columnIndex) => {
              let innerPart: React.ReactNode = col.title;
              if (col.sortFunc) {
                const isSorted = sortingProperty === columnIndex;
                innerPart = (
                  <TableSortLabel
                    active={isSorted}
                    direction={sortDirection}
                    onClick={() => {
                      // Re-clicking the active column toggles; switching
                      // columns starts from the column/table default.
                      setSortDirection(
                        isSorted
                          ? sortDirection === "asc"
                            ? "desc"
                            : "asc"
                          : col.defaultSortDirection ||
                              defaultSort?.direction ||
                              "asc"
                      );
                      setSortProperty(col.property);
                    }}
                  >
                    {col.title}
                    {isSorted ? (
                      // Visually hidden text announcing the sort state to
                      // screen readers.
                      <Box
                        component="span"
                        sx={{
                          border: 0,
                          clip: "rect(0 0 0 0)",
                          height: 1,
                          margin: -1,
                          overflow: "hidden",
                          padding: 0,
                          position: "absolute",
                          top: 20,
                          width: 1,
                        }}
                      >
                        {sortDirection === "desc"
                          ? "sorted descending"
                          : "sorted ascending"}
                      </Box>
                    ) : null}
                  </TableSortLabel>
                );
              }
              return (
                <TableCell
                  key={col.property}
                  padding="normal"
                  style={{ fontSize: 16, whiteSpace: "nowrap" }}
                >
                  {innerPart}
                </TableCell>
              );
            })}
          </TableRow>
        </TableHead>
        <TableBody>
          {resultData.map((row, index) => {
            const key = getKey ? getKey(row) : index;
            const rowData: RowData = props.rowDisplay?.(row, index) || {};
            return (
              <TableRow
                key={key}
                hover
                className={`table-row-${index % 2 === 0 ? "even" : "odd"} ${
                  rowData.className || ""
                }`}
                style={rowData.style}
                selected={rowData.selected}
              >
                {columns.map((col) => (
                  <TableCell
                    key={col.property}
                    style={{ whiteSpace: "nowrap" }}
                    onClick={
                      onClickCell
                        ? () => onClickCell(row, index, col.property)
                        : undefined
                    }
                  >
                    {col.render(row, index)}
                  </TableCell>
                ))}
              </TableRow>
            );
          })}
        </TableBody>
        {!!props.footer && (
          <TableFooter>
            <TableRow>
              <TableCell
                colSpan={columns.length}
                style={{ textAlign: "center" }}
              >
                {props.footer}
              </TableCell>
            </TableRow>
          </TableFooter>
        )}
      </Table>
    </TableContainer>
  );
}
