import { MutableRefObject, useEffect, useRef, useState } from 'react';
import * as deeplab from '@tensorflow-models/deeplab';
import { DeepLabOutput } from '@tensorflow-models/deeplab/dist/types';
import '@tensorflow/tfjs-core';
import '@tensorflow/tfjs-converter';
import '@tensorflow/tfjs-backend-webgl';
import '@tensorflow/tfjs-backend-cpu';
import { handleRemoveBgApi } from 'src/api/useRemoveBgApi';

/**
 * Loads the DeepLab semantic-segmentation model using the `'pascal'` base
 * (presumably the PASCAL VOC checkpoint) with 2-byte weight quantization.
 *
 * @returns A promise resolving to the loaded segmentation model.
 */
async function loadModel() {
  const segmenter = await deeplab.load({ base: 'pascal', quantizationBytes: 2 });
  return segmenter;
}

/**
 * Segments `image` with `model` and hands the segmentation output to
 * `renderPrediction`, returning whatever it produces (a canvas data URL string).
 *
 * @param model - A loaded DeepLab segmentation model.
 * @param image - The image to segment and cut out.
 */
async function predict(model: deeplab.SemanticSegmentation, image: HTMLImageElement) {
  return renderPrediction(image, await model.segment(image));
}

/**
 * Makes pixels of `originalImage` transparent wherever `mask` has the RGB colour `color`.
 *
 * The image is drawn onto an offscreen canvas of `width` x `height`. For every pixel whose mask
 * R, G and B bytes equal `color[0..2]`, that pixel's alpha is set to 0. Mask alpha is ignored and
 * no other pixel is modified.
 *
 * @param color - RGB triple (`[r, g, b]`) to match in the mask.
 * @param originalImage - Image to cut out; it is drawn scaled to `width` x `height`.
 * @param mask - Per-pixel label colours; must cover exactly `width` x `height` pixels.
 * @param width - Output width in pixels.
 * @param height - Output height in pixels.
 * @returns The masked image as a data URL from `canvas.toDataURL()`.
 * @throws Error if no 2D canvas context is available or the mask size does not match the output.
 */
function removeColor(color: number[], originalImage: HTMLImageElement, mask: ImageData, width: number, height: number) {
  const canvas = document.createElement('canvas');
  canvas.width = width;
  canvas.height = height;

  const ctx = canvas.getContext('2d');
  if (!ctx) {
    throw new Error('removeColor: 2D canvas context is not available');
  }
  // Assigning canvas.width/height resets the context state, so smoothing must be configured
  // only after the canvas has been sized.
  ctx.imageSmoothingEnabled = false;
  ctx.drawImage(originalImage, 0, 0, width, height);

  const imageData = ctx.getImageData(0, 0, width, height);
  const pixels = imageData.data;
  // Read the mask bytes directly: drawing them to a second canvas and reading them back adds
  // work and can perturb RGB values through alpha premultiplication.
  const maskPixels = mask.data;
  if (maskPixels.length !== pixels.length) {
    throw new Error(`removeColor: mask is ${mask.width}x${mask.height}, expected ${width}x${height}`);
  }

  const [r, g, b] = color;
  for (let i = 0; i < maskPixels.length; i += 4) {
    if (maskPixels[i] === r && maskPixels[i + 1] === g && maskPixels[i + 2] === b) {
      // Clear only this pixel's alpha; writing past i + 3 would corrupt neighbouring pixels.
      pixels[i + 3] = 0;
    }
  }

  ctx.putImageData(imageData, 0, 0);
  return canvas.toDataURL();
}

/**
 * Turns a DeepLab segmentation result into a cut-out of `image`.
 *
 * The segmentation map is wrapped in an `ImageData` mask (assumes it holds
 * `width * height` RGBA pixels, as the ImageData constructor requires). Pixels the mask
 * colours black `[0, 0, 0]` — presumably DeepLab's background class colour — are then
 * removed via `removeColor`.
 *
 * @returns The data URL produced by `removeColor`, sized to the segmentation output.
 */
function renderPrediction(image: HTMLImageElement, prediction: DeepLabOutput) {
  const backgroundColor = [0, 0, 0];
  const mask = new ImageData(prediction.segmentationMap, prediction.width, prediction.height);
  return removeColor(backgroundColor, image, mask, prediction.width, prediction.height);
}

/** Upper bound, in pixels, for each side of the image returned by `autoRemoveBackground`. */
const MAX_RESULT_SIZE = 800;

/**
 * Decodes a base64 data URL (`data:<mime>[;params];base64,<payload>`) into a Blob
 * tagged with the URL's MIME type.
 *
 * @throws Error if `dataUrl` is not a base64 data URL; `atob` throws if the payload is malformed.
 */
function dataUrlToBlob(dataUrl: string): Blob {
  const commaIndex = dataUrl.indexOf(',');
  const header = commaIndex === -1 ? '' : dataUrl.slice(0, commaIndex);
  if (!header.startsWith('data:') || !header.endsWith(';base64')) {
    throw new Error('Expected a base64-encoded data URL');
  }
  const mimeType = header.slice('data:'.length).split(';')[0] ?? '';
  const binary = atob(dataUrl.slice(commaIndex + 1));
  const bytes = Uint8Array.from(binary, (ch) => ch.charCodeAt(0));
  return new Blob([bytes], { type: mimeType });
}

/**
 * Encodes raw bytes as a `data:image/png;base64,...` URL.
 * NOTE(review): the PNG MIME type is hard-coded; this assumes the remove-bg API always
 * responds with PNG bytes — confirm against the API contract.
 */
function bytesToPngDataUrl(buffer: ArrayBuffer | ArrayLike<number>): string {
  const bytes = new Uint8Array(buffer);
  // Convert in chunks: one fromCharCode call per byte is slow, while spreading the whole
  // buffer at once can exceed the engine's maximum argument count.
  const chunkSize = 0x8000;
  let binary = '';
  for (let offset = 0; offset < bytes.length; offset += chunkSize) {
    binary += String.fromCharCode(...Array.from(bytes.subarray(offset, offset + chunkSize)));
  }
  return `data:image/png;base64,${window.btoa(binary)}`;
}

/**
 * Decodes `src` and re-encodes it through a canvas. The image is scaled down (never up)
 * so that neither side exceeds the given limits, and the aspect ratio is preserved.
 *
 * @returns A promise for the canvas `toDataURL()` string.
 * Rejects if the image fails to decode or no 2D canvas context is available.
 */
function resizeDataUrl(src: string, maxWidth = MAX_RESULT_SIZE, maxHeight = MAX_RESULT_SIZE): Promise<string> {
  return new Promise<string>((resolve, reject) => {
    const img = new Image();
    img.onload = () => {
      const scale = Math.min(1, maxWidth / img.width, maxHeight / img.height);
      const canvas = document.createElement('canvas');
      canvas.width = Math.max(1, Math.round(img.width * scale));
      canvas.height = Math.max(1, Math.round(img.height * scale));
      const ctx = canvas.getContext('2d');
      if (!ctx) {
        reject(new Error('2D canvas context is not available'));
        return;
      }
      ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
      resolve(canvas.toDataURL());
    };
    img.onerror = () => reject(new Error('Failed to decode the background-removed image'));
    // Assign src only after both handlers are attached so no load/error event can be missed.
    img.src = src;
  });
}

/**
 * React hook for removing image backgrounds.
 *
 * On mount it loads the DeepLab model via `loadModel` into a ref and sets `modelLoaded` once
 * it is ready (load failures are logged). `autoRemoveBackground` does not use that model: it
 * uploads the image to the remote remove-bg API (`handleRemoveBgApi`) and returns the result
 * downscaled to at most 800x800.
 *
 * @returns `autoRemoveBackground` and the `modelLoaded` flag.
 */
const useRemoveImageBackground = () => {
  // NOTE(review): the model is stored here but not read anywhere in this file.
  const model: MutableRefObject<deeplab.SemanticSegmentation | null> = useRef<deeplab.SemanticSegmentation | null>(null);
  const [modelLoaded, setModelLoaded] = useState(false);

  /**
   * Sends `image` to the remove-bg API and returns the cut-out as a data URL, downscaled so
   * neither side exceeds 800px.
   *
   * @param image - Image whose `src` must be a base64 data URL.
   * @returns The resized result, or `undefined` if any step fails: a non-data-URL source, an
   *   API error, or a decode/canvas failure. The failure is logged.
   */
  const autoRemoveBackground = async (image: HTMLImageElement): Promise<string | undefined> => {
    try {
      const formData = new FormData();
      formData.append('image', dataUrlToBlob(image.src));
      formData.append('size', 'auto');

      // Presumably an HTTP-client response whose `data` holds the raw image bytes
      // (e.g. responseType 'arraybuffer') — verify against handleRemoveBgApi.
      const response = await handleRemoveBgApi(formData);
      return await resizeDataUrl(bytesToPngDataUrl(response.data));
    } catch (error: unknown) {
      console.error('useRemoveImageBackground: background removal failed', error);
      return undefined;
    }
  };

  useEffect(() => {
    // Guards against setting state after the component has unmounted.
    let cancelled = false;
    loadModel()
      .then((loadedModel) => {
        if (cancelled) return;
        model.current = loadedModel;
        setModelLoaded(true);
      })
      .catch((error: unknown) => {
        console.error('useRemoveImageBackground: failed to load the DeepLab model', error);
      });
    return () => {
      cancelled = true;
    };
  }, []);

  return {
    autoRemoveBackground,
    modelLoaded,
  };
};

export default useRemoveImageBackground;
