import { Container, Typography, useTheme } from "@mui/material";
import { Alert, Box, Loader } from "components";
import { PageLayout } from "modules";
import PropTypes from "prop-types";
import React, { useEffect } from "react";
import ImageInference from "views/Inference/ImageInference";
import TextInference from "views/Inference/TextInference";

import packageJson from "../../../package.json";

/**
 * Model-type codes (the `model_type.code` of a training record) that are
 * rendered with the image inference view.
 */
const IMAGE_MODELS_CODES = [
  "object_detection",
  "image_classification",
];

/**
 * Model-type codes handled by the text inference view.
 * NOTE(review): not referenced in this file's visible code — the page treats
 * every code that is not in IMAGE_MODELS_CODES as a text model.
 */
const TEXT_MODELS_CODES = [
  "text_classification",
];

function Inference({ getTrainingStates, handleCreateQuery }) {
  const theme = useTheme();

  if (getTrainingStates?.isError)
    return (
      <Alert
        alertType="error"
        toggle={getTrainingStates.isError}
        handleCloseError={getTrainingStates.refetch}
        duration={0}
      >
        {getTrainingStates.error?.message}
      </Alert>
    );

  return getTrainingStates.status === "fulfilled" ? (
    <PageLayout>
      <Container
        maxWidth="xl"
        sx={{
          minHeight: "90vh",
          background: `linear-gradient(135deg, ${theme.palette.background.paper} 0%, ${theme.palette.background.default} 100%)`,
          py: 4,
        }}
      >
        <Box
          sx={{
            display: "flex",
            flexDirection: "column",
            alignItems: "center",
            maxWidth: "800px",
            mx: "auto",
            px: 2,
          }}
        >
          <Box
            sx={{
              textAlign: "center",
              mb: 6,
              pt: 4,
            }}
          >
            <Typography
              variant="h3"
              component="h1"
              sx={{
                fontWeight: 700,
                background: (_theme) =>
                  `linear-gradient(135deg, ${_theme.palette.primary.main}, ${_theme.palette.secondary.main})`,
                backgroundClip: "text",
                WebkitBackgroundClip: "text",
                color: "transparent",
                mb: 2,
              }}
            >
              {getTrainingStates.data.training.project ||
                (IMAGE_MODELS_CODES.includes(getTrainingStates.data.training?.model_type.code)
                  ? "AI Image Inference"
                  : "AI Text Inference")}
            </Typography>
            <Typography
              variant="h6"
              color="text.secondary"
              sx={{
                maxWidth: "600px",
                mx: "auto",
                lineHeight: 1.6,
                fontWeight: 400,
              }}
            >
              {IMAGE_MODELS_CODES.includes(getTrainingStates.data.training?.model_type.code)
                ? `Analysis powered by our advanced ${getTrainingStates.data.training?.model_type.model_type} model`
                : `Analysis powered by our advanced ${getTrainingStates.data.training?.model_type.model_type} model`}
            </Typography>
          </Box>

          <Box sx={{ width: "100%" }}>
            {IMAGE_MODELS_CODES.includes(getTrainingStates.data.training?.model_type.code) ? (
              <ImageInference
                handleCreateQuery={handleCreateQuery}
                getTrainingStates={getTrainingStates}
              />
            ) : (
              <TextInference handleCreateQuery={handleCreateQuery} />
            )}
          </Box>
        </Box>
        <Typography
          variant="caption"
          color="text.secondary"
          sx={{
            position: "absolute",
            top: 10,
            right: 15,
          }}
        >
          Version {packageJson.version}
        </Typography>
      </Container>
    </PageLayout>
  ) : (
    <Loader />
  );
}

// Runtime prop validation via prop-types; both props are mandatory.
Inference.propTypes = {
  handleCreateQuery: PropTypes.func.isRequired,
  getTrainingStates: PropTypes.object.isRequired,
};

export default Inference;
