import JSZip from 'jszip'
import { nonMaximumSuppression } from '@/non-maximum-suppression'

const ort = require('onnxruntime-web')

/**
 * Runs an ONNX detection model through onnxruntime-web.
 *
 * Instances are created with the static factories (`fromUrl`, `fromZip`,
 * `initialize`); the constructor only resets all fields to `null`.
 */
class Inference {
  /** Values of `parameters.network_type` that `initialize` accepts. */
  static SUPPORTED_MODEL_TYPES = ['detection']

  constructor() {
    this.modelType = null
    // Lookup used by `run` as `classLabels[classIndex + 1][0]`.
    this.classLabels = null
    this.inputWidth = null
    this.inputHeight = null
    this.inputChannels = null
    // Factor that converts heatmap cell units into input-image pixels.
    this.heatmapScaling = null
    this.inferenceSession = null
  }

  /**
   * Detects objects in an image.
   *
   * @param {*} img - Any image source accepted by `CanvasRenderingContext2D.drawImage`.
   * @param {Object} [options]
   * @param {number} [options.confidenceThreshold=0.3] - Heatmap cells whose best
   *   class score is below this value are discarded.
   * @param {boolean} [options.withRotation=true] - When true and the size map has
   *   four channels, an orientation is derived from channels 2 and 3.
   * @returns {Promise<Array<Object>>} Detections `{ label, confidence, x, y, width,
   *   height, orientation }` after non-maximum suppression. `x`/`width` are divided
   *   by the model input width and `y`/`height` by the input height; `orientation`
   *   is the `Math.atan2` result in radians, or 0 when rotation is not evaluated.
   * @throws {Error} If no inference session has been set up.
   */
  async run(img, { confidenceThreshold = 0.3, withRotation = true } = {}) {
    if (!this.inferenceSession) {
      throw new Error('Inference session is not initialized.')
    }

    const width = this.inputWidth
    const height = this.inputHeight
    const channels = this.inputChannels

    // Resize the image to the model input resolution by drawing it onto a canvas.
    const canvas = document.createElement('canvas')
    canvas.setAttribute('width', width)
    canvas.setAttribute('height', height)
    const ctx = canvas.getContext('2d')
    ctx.drawImage(img, 0, 0, width, height)
    const { data: pixels } = ctx.getImageData(0, 0, width, height)

    // Canvas pixels are interleaved RGBA (HWC); the model input is planar (NCHW).
    // Each colour value is mapped from [0, 255] to [-1, 1]; alpha is dropped.
    const nPixels = height * width
    const imageArray = new Float32Array(channels * nPixels)
    for (let i = 0; i < nPixels; i++) {
      imageArray[i] = pixels[4 * i] / 127.5 - 1
      imageArray[nPixels + i] = pixels[4 * i + 1] / 127.5 - 1
      imageArray[2 * nPixels + i] = pixels[4 * i + 2] / 127.5 - 1
    }

    const imageTensor = new ort.Tensor('float32', imageArray, [1, channels, height, width])
    const outputMaps = await this.inferenceSession.run({ input: imageTensor })

    // Outputs are indexed as NCHW as well: channel c of cell i lives at c * nCells + i.
    const classMap = outputMaps.detection_class_map
    const sizeMap = outputMaps.detection_size_map
    const offsetMap = outputMaps.detection_offset_map

    const nClasses = classMap.dims[1]
    const mapWidth = classMap.dims[3]
    const nCells = classMap.dims[2] * mapWidth

    const heatmapScaling = this.heatmapScaling
    const rotationRegression = sizeMap.dims[1] === 4 && withRotation

    let detections = []
    for (let i = 0; i < nCells; i++) {
      // Pick the best-scoring class for this cell (ties go to the higher index).
      let classIndex = 0
      let confidence = 0
      for (let j = 0; j < nClasses; j++) {
        const conf = classMap.data[j * nCells + i]
        if (conf >= confidence) {
          classIndex = j
          confidence = conf
        }
      }

      if (confidence < confidenceThreshold) {
        continue
      }

      // Labels are looked up with an offset of one; classes without an entry are skipped.
      // NOTE(review): presumably index 0 of the label file is a background class - confirm.
      const mappedLabel = this.classLabels[classIndex + 1]
      if (!mappedLabel) {
        continue
      }

      const offsetX = offsetMap.data[i]
      const offsetY = offsetMap.data[nCells + i]

      // Cell column/row plus the regressed offset, scaled to input-image pixels.
      const x = ((i % mapWidth) + offsetX) * heatmapScaling
      const y = (Math.floor(i / mapWidth) + offsetY) * heatmapScaling
      const w = sizeMap.data[i] * heatmapScaling
      const h = sizeMap.data[nCells + i] * heatmapScaling

      // NOTE(review): channels 2 and 3 look like the sine and cosine of the angle - confirm.
      const rotation = rotationRegression
        ? Math.atan2(sizeMap.data[2 * nCells + i], sizeMap.data[3 * nCells + i])
        : 0

      detections.push({
        label: mappedLabel[0],
        confidence: confidence,
        x: x,
        y: y,
        width: w,
        height: h,
        orientation: rotation,
      })
    }

    detections = nonMaximumSuppression(detections, {})

    // Detections are returned in relative image coordinates.
    for (const detection of detections) {
      detection.x /= width
      detection.y /= height
      detection.width /= width
      detection.height /= height
    }

    return detections
  }

  /**
   * Downloads a model archive and builds an `Inference` from it.
   *
   * @param {string} url - Location of the ZIP archive (fetched with CORS enabled).
   * @returns {Promise<Inference>}
   * @throws {Error} If the server answers with a non-success status, or for any
   *   reason `fromZip` throws.
   */
  static async fromUrl(url) {
    const response = await fetch(url, {
      method: 'GET',
      mode: 'cors',
    })
    // Without this check an HTTP error page would be handed to the ZIP parser
    // and surface as a misleading extraction error.
    if (!response.ok) {
      throw new Error(
        `Failed to download AI model from "${url}": ${response.status} ${response.statusText}`
      )
    }
    const zipBlob = await response.blob()
    return Inference.fromZip(zipBlob)
  }

  /**
   * Builds an `Inference` from a ZIP archive containing `parameters.json`,
   * `class_label.json` and `model.onnx`.
   *
   * @param {*} zipBlob - Archive data in any form accepted by `JSZip#loadAsync`.
   * @returns {Promise<Inference>}
   * @throws {Error} 'Error while extracting AI model from ZIP-file.' when the
   *   archive cannot be read or parsed; the underlying error is attached as `cause`.
   *   Errors raised by `initialize` are passed through unchanged.
   */
  static async fromZip(zipBlob) {
    let extracted
    try {
      const zip = await new JSZip().loadAsync(zipBlob)
      const readEntry = (name, type) => {
        const entry = zip.file(name)
        if (!entry) {
          throw new Error(`"${name}" is missing from the ZIP-file.`)
        }
        return entry.async(type)
      }

      const parameters = JSON.parse(await readEntry('parameters.json', 'string'))
      const classLabels = JSON.parse(await readEntry('class_label.json', 'string'))
      const onnxModel = await readEntry('model.onnx', 'uint8array')
      extracted = { parameters, classLabels, onnxModel }
    } catch (e) {
      throw new Error('Error while extracting AI model from ZIP-file.', { cause: e })
    }

    // Deliberately outside the try block: initialization failures (e.g. an
    // unsupported model type) keep their own, more specific message.
    return Inference.initialize(extracted.parameters, extracted.onnxModel, extracted.classLabels)
  }

  /**
   * Creates an `Inference` from already extracted model files.
   *
   * @param {Object} parameters - Parsed `parameters.json`; reads `network_type`,
   *   `parameters.dataset.input_width`, `parameters.dataset.input_height` and, for
   *   detection models, `parameters.detection.output_scaling.scale`.
   * @param {*} onnxModel - Model data passed to `ort.InferenceSession.create`.
   * @param {Object} classLabels - Parsed `class_label.json`.
   * @returns {Promise<Inference>}
   * @throws {Error} If `parameters.network_type` is not a supported model type.
   */
  static async initialize(parameters, onnxModel, classLabels) {
    const inference = new Inference()
    inference.modelType = parameters.network_type
    if (!Inference.SUPPORTED_MODEL_TYPES.includes(inference.modelType)) {
      throw new Error(`Unsupported model type "${inference.modelType}".
          Supported types: ${Inference.SUPPORTED_MODEL_TYPES}`)
    }

    inference.classLabels = classLabels
    inference.inputWidth = parameters.parameters.dataset.input_width
    inference.inputHeight = parameters.parameters.dataset.input_height
    // The preprocessing in `run` always writes three colour planes (R, G, B).
    inference.inputChannels = 3

    if (inference.modelType === 'detection') {
      inference.heatmapScaling = parameters.parameters.detection.output_scaling.scale
    }

    inference.inferenceSession = await ort.InferenceSession.create(onnxModel)

    return inference
  }
}

export default Inference
