<template>
  <div class="row p-0 position-relative">
    <!-- Overlay shown while `loading` > 0, i.e. while a statistics request is pending. -->
    <loading-overlay v-if="loading > 0" light />
    <!-- One column per entry of `charts`: the chart title, then the chart or a placeholder. -->
    <div v-for="(chart, i) in charts" :key="`chart-${i}`" class="p-0 col-sm-12 col-lg-6 col-xl-4">
      <div class="font-weight-700 text-center font-size-small">
        {{ chart.title }}
      </div>
      <!-- The chart is rendered only when both series are non-empty; otherwise the placeholder below is shown. -->
      <basic-chart
        v-if="chart.xData.length > 0 && chart.yData.length > 0"
        :chart-type="chart.type"
        :x-data="chart.xData"
        :y-data="chart.yData"
        :overwrite-options="chart.overwriteOptions"
        :height="200"
      />
      <div v-else class="pt-5 pb-5 text-center text-dark">
        <small> no data available </small>
      </div>
    </div>
  </div>
</template>

<script>
import _ from 'lodash'
import Cookie from 'js-cookie'
import { mapState } from 'vuex'

import BasicChart from '@/components/charts/BasicChart'
import LoadingOverlay from '@/components/LoadingOverlay'

export default {
  name: 'ModelMetricsOverview',
  components: { LoadingOverlay, BasicChart },
  props: {
    // Model whose statistics are charted. This component reads `id`,
    // `network_type`, `inference_data` and `class_labels` from it.
    model: {
      type: Object,
      default() {
        return {}
      },
    },
  },
  data() {
    return {
      // Number of requests in flight; the loading overlay is shown while > 0.
      loading: 0,
      // Fixed chart slots:
      //   [0] per-class metric (F1 or AP) of the latest statistics entry
      //   [1] accuracy / mAP per iteration
      //   [2] training loss per iteration
      charts: [
        {
          title: '',
          xData: [],
          yData: [],
          type: 'bar',
          overwriteOptions: {
            yAxis: {
              min: 0.0,
              max: 1.0,
            },
          },
        },
        {
          title: '',
          xData: [],
          yData: [],
          type: 'line',
          overwriteOptions: {
            yAxis: {
              min: 0.0,
              max: 1.0,
            },
          },
        },
        {
          title: '',
          xData: [],
          yData: [],
          type: 'line',
        },
      ],
    }
  },
  computed: {
    ...mapState(['classLabelMap', 'jobs']),
  },

  mounted() {
    this.getStatistics()
  },
  methods: {
    /**
     * Fetch the model's training statistics and fill the chart slots:
     *   0 - per-class F1 (classification) or AP (detection) of the latest entry,
     *       falling back to getEvalStatistics() when the statistics contain
     *       none and the model has inference data;
     *   1 - accuracy (classification) or mAP (detection) per iteration;
     *   2 - training loss per iteration.
     *
     * A chart whose data cannot be extracted is cleared. If the request itself
     * fails, every chart is cleared. Errors are logged, never re-thrown.
     *
     * @returns {Promise<void>}
     */
    async getProgressStatistics() {
      // The prop defaults to `{}`, so an id is required before building a URL.
      if (!this.model || _.isNil(this.model.id)) {
        return
      }

      this.loading++
      try {
        const isDetection = this.model.network_type === 'detection'
        const response = await this.$axios({
          method: 'get',
          url: `/api/models/${this.model.id}/statistics/?exclude_pr_curves=${isDetection}`,
        })
        const data = response.data

        const perClass = this.buildPerClassChart(data)
        if (perClass.xData.length > 0 && perClass.yData.length > 0) {
          this.setChart(0, perClass)
        } else if (this.model.inference_data) {
          // No per-class metrics in the training statistics: compute them
          // from the stored inference results via the evaluation API.
          await this.getEvalStatistics()
        } else {
          this.setChart(0)
        }

        this.setChart(1, this.buildSummaryChart(data))
        this.setChart(2, this.buildLossChart(data))
      } catch (e) {
        // The request failed: clear every chart so no stale data remains.
        console.warn(e)
        this.setChart(0)
        this.setChart(1)
        this.setChart(2)
      } finally {
        this.loading--
      }
    },
    /**
     * Fill chart slot 0 with per-class metrics computed by the evaluation API
     * from the model's inference results: F1 per class for classification,
     * AP at IoU 0.50 per class for detection.
     *
     * On any error the chart is cleared and a warning is logged.
     *
     * @returns {Promise<void>}
     */
    async getEvalStatistics() {
      if (!this.model || _.isNil(this.model.id)) {
        return
      }

      this.loading++
      try {
        const classLabels = _.clone(this.model.class_labels)
        const response = await this.$axios({
          method: 'post',
          withCredentials: true,
          headers: {
            'X-Requested-With': 'XMLHttpRequest',
            'X-CSRFToken': Cookie.get('csrftoken'),
          },
          url: `/api/inference-results/evaluate/?network_model__id=${this.model.id}`,
          data: {
            task: this.model.network_type,
            classes: classLabels,
          },
        })
        const results = response.data.evaluation_results

        let title = ''
        const xData = []
        const yData = []

        if (results.evaluated_items > 0) {
          if (this.model.network_type === 'classification') {
            const perClassF1 = results['f1']['per_class']
            // `per_class` is read by position in `classLabels` (the submitted class list).
            _.forEach(classLabels, (classId, i) => {
              xData.push(this.getClassLabelName(classId))
              yData.push(perClassF1[i])
            })
            title = 'F1 Score'
          } else if (this.model.network_type === 'detection') {
            const perClassMetrics = results['iou_0.50']['per_class']
            for (const key in perClassMetrics) {
              const classId = parseInt(key, 10)
              xData.push(this.getClassLabelName(classId))
              yData.push(perClassMetrics[key].ap)
            }
            title = 'average precision (AP)'
          }
        }

        this.setChart(0, { title, xData, yData })
      } catch (e) {
        console.warn(e)
        this.setChart(0)
      } finally {
        this.loading--
      }
    },
    /**
     * Load all statistics for the current model.
     *
     * @returns {Promise<void>} settles once all charts have been updated
     */
    getStatistics() {
      return this.getProgressStatistics()
    },
    /**
     * Build the per-class series from the most recent statistics entry:
     * F1 per class (classification) or AP per class (detection).
     *
     * @param {Array} data - statistics entries returned by the statistics API
     * @returns {{title: string, xData: Array, yData: Array}} empty series when
     *   there is nothing to plot or the entry cannot be read
     */
    buildPerClassChart(data) {
      let title = ''
      const xData = []
      const yData = []
      try {
        if (data.length > 0) {
          const { statistics } = data[data.length - 1]
          let perClass = {}
          if (this.model.network_type === 'classification') {
            perClass = statistics.eval.classification.multi_class.per_class.f1_score
            title = 'F1 Score'
          } else if (this.model.network_type === 'detection') {
            perClass = statistics.eval.detection.average_precisions
            title = 'average precision (AP)'
          }
          for (const classId in perClass) {
            xData.push(this.getClassLabelName(parseInt(classId, 10), classId))
            yData.push(perClass[classId])
          }
        }
      } catch (e) {
        // Discard values collected before the failure so a malformed entry
        // never shows up as a truncated chart.
        console.warn(e)
        return { title: '', xData: [], yData: [] }
      }
      return { title, xData, yData }
    },
    /**
     * Build the per-iteration summary series: accuracy for classification,
     * mean of the per-class average precisions (mAP) for detection.
     *
     * @param {Array} data - statistics entries returned by the statistics API
     * @returns {{title: string, xData: Array, yData: Array}} empty series on error
     */
    buildSummaryChart(data) {
      let title = ''
      const xData = []
      const yData = []
      try {
        for (const { iteration, statistics } of data) {
          if (this.model.network_type === 'classification') {
            const accuracy = statistics.eval.classification.accuracy
            xData.push(iteration)
            yData.push(accuracy)
            title = 'accuracy'
          } else if (this.model.network_type === 'detection') {
            const averagePrecisions = _.values(statistics.eval.detection.average_precisions)
            if (averagePrecisions.length === 0) {
              // The mean over zero classes would be NaN; leave the point out.
              continue
            }
            xData.push(iteration)
            yData.push(_.mean(averagePrecisions))
            title = 'mean average precision (mAP)'
          }
        }
      } catch (e) {
        console.warn(e)
        return { title: '', xData: [], yData: [] }
      }
      return { title, xData, yData }
    },
    /**
     * Build the training-loss series, preferring `total_loss` and falling back
     * to `loss` when an entry has no `total_loss` scalar.
     *
     * @param {Array} data - statistics entries returned by the statistics API
     * @returns {{title: string, xData: Array, yData: Array}} empty series on error
     */
    buildLossChart(data) {
      let title = ''
      const xData = []
      const yData = []
      try {
        for (const { iteration, statistics } of data) {
          const scalars = statistics.train.scalars
          const loss = scalars.total_loss !== undefined ? scalars.total_loss : scalars.loss
          xData.push(iteration)
          yData.push(loss)
          title = 'loss'
        }
      } catch (e) {
        console.warn(e)
        return { title: '', xData: [], yData: [] }
      }
      return { title, xData, yData }
    },
    /**
     * Resolve a class id to its display name via the store's `classLabelMap`.
     *
     * @param {*} classId - key looked up in `classLabelMap`
     * @param {*} [fallback=classId] - value returned when the id has no entry
     * @returns {*} the entry's `name`, or `fallback`
     */
    getClassLabelName(classId, fallback = classId) {
      const classLabel = this.classLabelMap[classId]
      return classLabel !== undefined ? classLabel.name : fallback
    },
    /**
     * Replace the title and series of one chart slot. Calling it with only
     * an index clears the chart.
     *
     * @param {number} index - slot in `charts`
     * @param {{title?: string, xData?: Array, yData?: Array}} [series]
     */
    setChart(index, { title = '', xData = [], yData = [] } = {}) {
      const chart = this.charts[index]
      chart.title = title
      chart.xData = xData
      chart.yData = yData
    },
  },
}
</script>

<style lang="scss" scoped>
// Project-wide SCSS. NOTE(review): presumably provides $white, $gray-300,
// $primary, $blue-variation and the `.border-radius-md` class used below — confirm.
@import '../../custom';

// NOTE(review): none of the classes below (.model, .no-hover, .highlight) are
// referenced in this component's template; they may be dead styles — confirm
// before removing.

// Clickable item (pointer cursor) with a white background that turns gray on hover.
.model {
  // NOTE(review): `display: contents` makes the element generate no box of its
  // own, so its background, padding and border-radius are not rendered.
  display: contents;

  @extend .border-radius-md;
  background: $white;

  padding: 0.5rem;
  cursor: pointer;

  &:hover {
    background: $gray-300;
  }
}

// Modifier: force the default cursor and full opacity, overriding other rules
// through !important.
.no-hover {
  cursor: default !important;
  opacity: 1 !important;
}

// Highlighted state: primary background with white text; uses $blue-variation
// as the background on hover.
.highlight {
  background: $primary;
  color: $white !important;

  &:hover {
    background: $blue-variation;
  }
}
</style>
