/**
 * @fileoverview This library provides a 3D drawing utility on a NxNxN grid in a
 * 1x1x1 space.
 */

import {NormalizedLandmark} from 'google3/third_party/mediapipe/web/solutions/index';
import * as THREE from 'three';

/**
 * Indices of two landmarks to be joined by a line segment. Only the first two
 * entries are read.
 */
type Connection = number[];

/**
 * Connections that are all drawn with the same material.
 */
type ConnectionList = Connection[];

/**
 * A single color group: `color` is the name of one of the configured
 * `definedColors`, or `undefined` to use the default material.
 */
interface ColorMapEntry<T> {
  color: string|undefined;
  list: T[];
}

/**
 * Groups of items (connections or landmark indices), each paired with the name
 * of the color to draw them in.
 */
type ColorMap<T> = Array<ColorMapEntry<T>>;

/**
 * An axis label: the DOM element showing it, the scene point it is attached to
 * and the grid value it stands for.
 */
interface NumberLabel {
  /** Scene-space point whose canvas projection positions the label. */
  position: THREE.Vector3;
  /** Grid-space value at `position`, before centering/fit adjustments. */
  value: number;
  /** Span whose text content displays the label. */
  element: HTMLSpanElement;
}

/**
 * Configuration for the landmark grid. Every option is optional; options that
 * are left out take the module's default values. Colors are numbers such as
 * 0xffffff.
 */
export interface LandmarkGridConfig {
  /** Color of the border line drawn around the grid. */
  axesColor?: number;
  /**
   * Line width of the border. WebGL implementations commonly ignore line
   * widths other than 1.
   */
  axesWidth?: number;
  backgroundColor?: number;
  /**
   * The "centered" attribute describes whether the grid should use the center
   * of the bounding box of the landmarks as the origin. The box is computed
   * from the landmarks above `minVisibility`, or from all landmarks when none
   * of them is visible.
   */
  centered?: boolean;
  /** Color of connections drawn without a named color. */
  connectionColor?: number;
  /** Line width of connections, including those drawn in `definedColors`. */
  connectionWidth?: number;
  /**
   * Named colors that the color maps passed to `updateLandmarks` can refer to
   * by `name`.
   */
  definedColors?: Array<{name: string; value: number;}>;
  /**
   * The "fitToGrid" attribute describes whether the landmarks should be
   * rescaled so that they fit inside the grid. The axis labels are rescaled
   * accordingly.
   */
  fitToGrid?: boolean;
  /** Text placed before the number in every axis label. */
  labelPrefix?: string;
  /** Text placed after the number in every axis label. */
  labelSuffix?: string;
  /** Color of landmarks whose visibility is above `minVisibility`. */
  landmarkColor?: number;
  /**
   * Radius of the landmark spheres in scene units (the grid cube is 100 units
   * wide).
   */
  landmarkSize?: number;
  /**
   * Only used with `fitToGrid`: fraction of the grid left empty on each side,
   * so landmarks stay within the central `1 - 2 * margin` of the range.
   */
  margin?: number;
  /**
   * A landmark is visible when its `visibility` is strictly greater than this
   * value. Landmarks without a `visibility` are never visible.
   */
  minVisibility?: number;
  /** Color of landmarks that are not visible (see `minVisibility`). */
  nonvisibleLandmarkColor?: number;
  /** Number of grid cells along each axis. */
  numCellsPerAxis?: number;
  /**
   * The "range" attribute describes the default numerical extent of each
   * axis: the grid spans [-range / 2, range / 2] on every axis, so the default
   * of 1 gives a 1x1x1 space.
   */
  range?: number;
  /**
   * Whether landmarks that are not visible, and connections touching them,
   * are still drawn.
   */
  showHidden?: boolean;
}


/**
 * Default value of every configuration option. Typed as `Required` so that an
 * option added to `LandmarkGridConfig` without a default here fails to
 * compile; the grid reads each option with a non-null assertion.
 */
const DEFAULT_CONFIG: Readonly<Required<LandmarkGridConfig>> = {
  axesColor: 0xffffff,
  axesWidth: 2,
  backgroundColor: 0,
  centered: false,
  connectionColor: 0x00ffff,
  connectionWidth: 3,
  definedColors: [],
  fitToGrid: false,
  labelPrefix: '',
  labelSuffix: '',
  landmarkColor: 0xaaaaaa,
  landmarkSize: 3,
  margin: 0,
  minVisibility: .65,
  nonvisibleLandmarkColor: 0xff7777,
  numCellsPerAxis: 3,
  range: 1,
  showHidden: true,
};

/** Point the camera always looks at. */
const ORIGIN = new THREE.Vector3();
/** Icon shown on the control button while the grid is rotating. */
const PAUSE_SRC =
    'https://fonts.gstatic.com/s/i/googlematerialicons/pause/v14/white-24dp/1x/gm_pause_white_24dp.png';
/** Icon shown on the control button while rotation is paused. */
const PLAY_SRC =
    'https://fonts.gstatic.com/s/i/googlematerialicons/play_arrow/v14/white-24dp/1x/gm_play_arrow_white_24dp.png';
/**
 * Material with `visible` turned off; given to non-visible landmarks when
 * `showHidden` is disabled so their meshes are not drawn.
 */
const HIDDEN_MATERIAL = new THREE.Material();
HIDDEN_MATERIAL.visible = false;

/**
 * This class makes a canvas instance where points can be drawn in a NxNxN grid
 * in a 1x1x1 space.
 */
export class LandmarkGrid {
  private readonly camera!: THREE.PerspectiveCamera;
  private readonly renderer!: THREE.WebGLRenderer;
  private readonly scene!: THREE.Scene;
  /** Distance of the camera from the origin, in scene units. */
  private readonly distance: number = 150;
  /** Edge length of the grid cube, in scene units. */
  private readonly size: number = 100;
  private readonly labels!:
      {x: NumberLabel[]; y: NumberLabel[]; z: NumberLabel[];};
  private readonly landmarkGroup: THREE.Group;
  private readonly connectionGroup: THREE.Group;
  private readonly container: HTMLDivElement;
  /**
   * Offset subtracted from every landmark when `centered` is enabled; remains
   * the zero vector otherwise.
   */
  private readonly origin: THREE.Vector3;
  private axesMaterial!: THREE.Material;
  private backgroundColor!: number;
  private connectionMaterial!: THREE.Material;
  /** Materials for the configured `definedColors`, keyed by color name. */
  private definedColors!: Map<string, THREE.Material>;
  private gridMaterial!: THREE.Material;
  private isRotating: boolean = true;
  private isVisible!: (e: NormalizedLandmark) => boolean;
  private landmarkGeometry!: THREE.BufferGeometry;
  private landmarkMaterial!: THREE.Material;
  private nonvisibleMaterial!: THREE.Material;
  private rotation: number = 0;
  private rotationSpeed: number = Math.PI / 180;
  private showHidden!: boolean;
  /** Whether an animation frame has been requested but has not run yet. */
  private frameRequested: boolean = false;
  private disposeQueue: THREE.BufferGeometry[] = [];
  private removeQueue: THREE.Object3D[] = [];
  private landmarks: NormalizedLandmark[] = [];
  private fitToGrid!: boolean;
  /**
   * Fraction of the half-range the farthest landmark may reach when
   * `fitToGrid` is enabled (`1 - 2 * margin`).
   */
  private sizeWhenFitted!: number;
  private numCellsPerAxis!: number;
  private labelSuffix!: string;
  private labelPrefix!: string;
  private centered!: boolean;
  private range!: number;

  /**
   * Creates the grid inside a new `div.landmark-grid-js` appended to `parent`
   * and schedules the first frame.
   *
   * Frames are rendered on demand, not in a continuous loop. A frame is
   * requested here, by every `updateLandmarks` call, on window resize and
   * while dragging. Requests made before the pending frame runs are merged
   * into it. Automatic rotation advances by `rotationSpeed` once per rendered
   * frame.
   *
   * @param parent Element receiving the grid container; its bounding box at
   *     construction time determines the initial canvas size.
   * @param config Options overriding the defaults.
   */
  constructor(parent: HTMLElement, config: LandmarkGridConfig = {}) {
    this.container = document.createElement('div');
    this.container.classList.add('landmark-grid-js');

    const canvas = document.createElement('canvas');
    this.container.appendChild(canvas);
    parent.appendChild(this.container);
    const parentBox = parent.getBoundingClientRect();

    this.setConfig({...DEFAULT_CONFIG, ...config});

    this.addPausePlay(this.container);

    this.camera =
        new THREE.PerspectiveCamera(75, parentBox.width / parentBox.height, 1);
    this.camera.position.x = this.distance;
    this.camera.lookAt(ORIGIN);

    this.renderer =
        new THREE.WebGLRenderer({canvas, alpha: true, antialias: true});
    // Half-transparent background in the configured color.
    this.renderer.setClearColor(new THREE.Color(this.backgroundColor), .5);
    this.renderer.setSize(
        Math.floor(parentBox.width), Math.floor(parentBox.height));
    window.addEventListener('resize', () => {
      this.handleResize();
    });

    this.scene = new THREE.Scene();

    this.drawAxes();
    this.labels = this.createAxesLabels();
    this.landmarkGroup = new THREE.Group();
    this.scene.add(this.landmarkGroup);
    this.connectionGroup = new THREE.Group();
    this.scene.add(this.connectionGroup);

    this.origin = new THREE.Vector3();

    this.setMouseDrag();

    this.requestFrame();
  }

  /**
   * Matches the renderer to the container's current size, and keeps the
   * camera's aspect ratio in sync so the scene is not stretched. Then redraws.
   */
  private handleResize() {
    const box = this.container.getBoundingClientRect();
    const width = Math.floor(box.width);
    const height = Math.floor(box.height);
    this.renderer.setSize(width, height);
    // A collapsed container would give a zero or non-finite aspect ratio.
    if (width > 0 && height > 0) {
      this.camera.aspect = width / height;
      this.camera.updateProjectionMatrix();
    }
    this.requestFrame();
  }

  /**
   * Creates the numeric labels along the three visible grid edges.
   *
   * @return The labels per axis, each with its scene position and grid value.
   */
  private createAxesLabels() {
    const labels = {
      x: [] as NumberLabel[],
      y: [] as NumberLabel[],
      z: [] as NumberLabel[],
    };

    const HALF_SIZE = this.size / 2;
    for (let i = 0; i < this.numCellsPerAxis; i++) {
      // X labels sit on grid lines 1..N (N of the N+1 lines). Line 0 is
      // skipped because the first y label already occupies that corner.
      const xValue = ((i + 1) / this.numCellsPerAxis - .5) * this.range;
      labels.x.push({
        position: new THREE.Vector3(
            (i + 1) / this.numCellsPerAxis * this.size - HALF_SIZE, -HALF_SIZE,
            HALF_SIZE),
        element: this.createLabel(xValue),
        value: xValue
      });
      // Z labels sit on grid lines 0..N-1. Line N is skipped because the last
      // x label already occupies that corner.
      const zValue = (i / this.numCellsPerAxis - .5) * this.range;
      labels.z.push({
        position: new THREE.Vector3(
            HALF_SIZE, -HALF_SIZE,
            i / this.numCellsPerAxis * this.size - HALF_SIZE),
        element: this.createLabel(zValue),
        value: zValue
      });
    }
    // Y labels cover all N+1 grid lines.
    for (let i = 0; i <= this.numCellsPerAxis; i++) {
      const yValue = (i / this.numCellsPerAxis - .5) * this.range;
      labels.y.push({
        position: new THREE.Vector3(
            -HALF_SIZE, i / this.numCellsPerAxis * this.size - HALF_SIZE,
            HALF_SIZE),
        element: this.createLabel(yValue),
        value: yValue,
      });
    }

    return labels;
  }

  /** Creates a label span showing `value` and appends it to the container. */
  private createLabel(value: number) {
    const el = document.createElement('span');
    el.classList.add('landmark-label-js');
    this.setLabel(el, value);
    this.container.appendChild(el);
    return el;
  }

  /** Writes `value` (2 significant digits) with prefix/suffix into `el`. */
  private setLabel(el: HTMLSpanElement, value: number) {
    el.textContent = this.labelPrefix + value.toPrecision(2) + this.labelSuffix;
  }

  /** Adds the three back grid planes and the border line to the scene. */
  private drawAxes() {
    const axes = new THREE.Group();
    const HALF_SIZE = this.size / 2;

    const grid = this.makeGrid(this.size, this.numCellsPerAxis);
    const xGrid = grid;
    const yGrid = grid.clone();
    const zGrid = grid.clone();

    xGrid.translateX(-HALF_SIZE);
    xGrid.rotateY(Math.PI / 2);
    axes.add(xGrid);

    yGrid.translateY(-HALF_SIZE);
    yGrid.rotateX(Math.PI / 2);
    axes.add(yGrid);

    zGrid.translateZ(-HALF_SIZE);
    axes.add(zGrid);

    const border = new THREE.BufferGeometry().setFromPoints([
      new THREE.Vector3(-HALF_SIZE, HALF_SIZE, HALF_SIZE),
      new THREE.Vector3(-HALF_SIZE, -HALF_SIZE, HALF_SIZE),
      new THREE.Vector3(HALF_SIZE, -HALF_SIZE, HALF_SIZE),
      new THREE.Vector3(HALF_SIZE, -HALF_SIZE, -HALF_SIZE),
      new THREE.Vector3(HALF_SIZE, HALF_SIZE, -HALF_SIZE),
      new THREE.Vector3(-HALF_SIZE, HALF_SIZE, -HALF_SIZE),
      new THREE.Vector3(-HALF_SIZE, HALF_SIZE, HALF_SIZE)
    ]);
    axes.add(new THREE.Line(border, this.axesMaterial));

    this.scene.add(axes);
  }

  /** Draws the scene and moves the label elements to match the camera. */
  private render() {
    this.renderer.render(this.scene, this.camera);
    this.setLabels();
  }

  /**
   * Schedules a single render on the next animation frame, advancing the
   * automatic rotation first. Calls made while a frame is already pending are
   * merged into that frame.
   */
  private requestFrame() {
    if (this.frameRequested) return;
    this.frameRequested = true;
    window.requestAnimationFrame(() => {
      this.frameRequested = false;
      if (this.isRotating) {
        this.rotation += this.rotationSpeed;
        this.camera.position.x = Math.sin(this.rotation) * this.distance;
        this.camera.position.z = Math.cos(this.rotation) * this.distance;
      }
      this.camera.lookAt(ORIGIN);
      this.render();
    });
  }

  /**
   * Returns the `definedColors` material named `colorName`. Falls back to the
   * default connection material when no name is given or the name is not
   * registered, so three.js never receives an undefined material.
   */
  private materialFor(colorName?: string): THREE.Material {
    const named = colorName ? this.definedColors.get(colorName) : undefined;
    return named ?? this.connectionMaterial;
  }

  /**
   * Assigns the material named `colorName` to the visible landmarks at the
   * given indices, or to every visible landmark when `landmarks` is omitted.
   * Indices that do not refer to a drawn landmark are ignored.
   */
  private colorLandmarks(landmarks?: number[], colorName?: string) {
    const material = this.materialFor(colorName);
    const meshList = this.landmarkGroup.children as THREE.Mesh[];
    const indices = landmarks ?? this.landmarks.map((_, i) => i);
    for (const index of indices) {
      const landmark = this.landmarks[index];
      if (landmark === undefined || !this.isVisible(landmark)) continue;
      meshList[index].material = material;
    }
  }

  /**
   * Replaces the drawn landmarks and connections.
   *
   * @param landmarks Landmarks to draw. They are copied, so centering and
   *     fitting never modify the caller's objects.
   * @param colorConnections Either a plain list of index pairs drawn in the
   *     default connection color, or a color map whose `color` names refer to
   *     `definedColors`.
   * @param colorLandmarks Optional color map recoloring visible landmarks by
   *     index.
   */
  updateLandmarks(
      landmarks: NormalizedLandmark[],
      colorConnections: ConnectionList|ColorMap<Connection> = [],
      colorLandmarks?: ColorMap<number>) {
    this.landmarkGroup.clear();
    this.connectionGroup.clear();
    this.clearResources();

    if (landmarks.length === 0) {
      this.landmarks = [];
      // Redraw so the meshes removed above also disappear from the canvas.
      this.requestFrame();
      return;
    }
    // Work on copies: centering and fitting below mutate the landmarks.
    this.landmarks = landmarks.map(this.copyLandmark);

    const connections = this.toColorMap(colorConnections);

    // Centering and fitting are driven by the visible landmarks, falling back
    // to all landmarks when none of them is visible.
    const visibleLandmarks = this.landmarks.filter((e) => this.isVisible(e));
    const referenceLandmarks =
        visibleLandmarks.length === 0 ? this.landmarks : visibleLandmarks;
    if (this.centered) {
      this.centerLandmarks(referenceLandmarks);
    }

    let scalingFactor = 1;
    if (this.fitToGrid) {
      scalingFactor = this.computeScalingFactor(referenceLandmarks);
      for (const landmark of this.landmarks) {
        landmark.x *= scalingFactor;
        landmark.y *= scalingFactor;
        landmark.z *= scalingFactor;
      }
    }
    this.updateAxisLabels(scalingFactor);

    const landmarkVectors: THREE.Vector3[] =
        this.landmarks.map(e => this.landmarkToVector(e));

    // Pose connections
    for (const connection of connections) {
      this.drawConnections(landmarkVectors, connection.list, connection.color);
    }

    // Pose landmarks
    for (let i = 0; i < this.landmarks.length; i++) {
      let material: THREE.Material;
      if (this.isVisible(this.landmarks[i])) {
        material = this.landmarkMaterial;
      } else {
        material = this.showHidden ? this.nonvisibleMaterial : HIDDEN_MATERIAL;
      }
      const sphere = new THREE.Mesh(this.landmarkGeometry, material);
      sphere.position.copy(landmarkVectors[i]);
      this.removeQueue.push(sphere);
      this.landmarkGroup.add(sphere);
    }

    // Color special landmarks
    if (colorLandmarks) {
      for (const colorDef of colorLandmarks) {
        this.colorLandmarks(colorDef.list, colorDef.color);
      }
    }

    this.requestFrame();
  }

  /**
   * Normalizes the `colorConnections` argument of `updateLandmarks`. A plain
   * list of connections (whose entries are arrays) becomes a single group in
   * the default color. A color map is returned unchanged.
   */
  private toColorMap(colorConnections: ConnectionList|
                     ColorMap<Connection>): ColorMap<Connection> {
    if (colorConnections.length > 0 && Array.isArray(colorConnections[0])) {
      return [{color: undefined, list: colorConnections as ConnectionList}];
    }
    return colorConnections as ColorMap<Connection>;
  }

  /**
   * Computes the factor applied to landmarks when `fitToGrid` is enabled.
   *
   * The raw factor from `getFitToGridFactor` is snapped down to discrete steps
   * of `RESCALE` (presumably so the scale does not change on every small
   * movement). When no usable finite, positive factor exists, 1 is returned
   * instead of producing Infinity/NaN coordinates. This happens, for example,
   * when every reference landmark sits exactly at the origin or when
   * `margin >= 0.5`.
   */
  private computeScalingFactor(landmarks: NormalizedLandmark[]): number {
    const rawScalingFactor = this.getFitToGridFactor(landmarks);
    if (!Number.isFinite(rawScalingFactor) || rawScalingFactor <= 0) {
      return 1;
    }
    const RESCALE = .5;
    const halfRange = this.range / 2;
    const numOfRescaleSteps =
        Math.ceil((1 / rawScalingFactor - 1) * halfRange / RESCALE);
    return 1 / (numOfRescaleSteps * RESCALE / halfRange + 1);
  }

  /**
   * Updates the axis label texts for the current centering and scaling.
   *
   * Landmarks were first shifted by `-this.origin` and then multiplied by
   * `scalingFactor`, so the grid value `v` maps back to `v / scalingFactor`
   * plus the origin. The y and z labels are shown in three.js orientation,
   * because `landmarkToVector` negates y and z. The origin therefore enters
   * those two axes with the opposite sign.
   */
  private updateAxisLabels(scalingFactor: number) {
    for (const label of this.labels.x) {
      this.setLabel(label.element, label.value / scalingFactor + this.origin.x);
    }
    for (const label of this.labels.y) {
      this.setLabel(label.element, label.value / scalingFactor - this.origin.y);
    }
    for (const label of this.labels.z) {
      this.setLabel(label.element, label.value / scalingFactor - this.origin.z);
    }
  }

  /**
   * Draws `connections` as line segments between the given scene points using
   * the material named `colorName`. Connections that reference a missing
   * landmark are skipped. When `showHidden` is off, so are connections
   * touching a non-visible landmark.
   */
  private drawConnections(
      landmarks: THREE.Vector3[], connections: ConnectionList,
      colorName?: string) {
    const points: THREE.Vector3[] = [];
    for (const [start, end] of connections) {
      const startLandmark = this.landmarks[start];
      const endLandmark = this.landmarks[end];
      if (startLandmark === undefined || endLandmark === undefined) continue;
      if (!this.showHidden &&
          (!this.isVisible(startLandmark) || !this.isVisible(endLandmark))) {
        continue;
      }
      points.push(landmarks[start], landmarks[end]);
    }
    const geometry = new THREE.BufferGeometry().setFromPoints(points);
    this.disposeQueue.push(geometry);
    const wireframe =
        new THREE.LineSegments(geometry, this.materialFor(colorName));
    this.removeQueue.push(wireframe);
    this.connectionGroup.add(wireframe);
  }

  /** Converts a landmark to scene coordinates. */
  private landmarkToVector(point: NormalizedLandmark): THREE.Vector3 {
    // The Y and Z orientations are flipped in three.js compared to the y and z
    // orientations in solutions
    return new THREE.Vector3(
        point.x * this.size / this.range, -point.y * this.size / this.range,
        -point.z * this.size / this.range);
  }

  /**
   * Builds a square wireframe of edge `size` subdivided into
   * `numSteps` x `numSteps` cells, centered on the origin in the XY plane.
   */
  private makeGrid(size: number, numSteps: number) {
    const grid = new THREE.Group();

    const plane = new THREE.PlaneGeometry(size, size);
    const edges = new THREE.EdgesGeometry(plane);
    const wireframe = new THREE.LineSegments(edges, this.gridMaterial);
    grid.add(wireframe);

    const stepPlaneSize = size / numSteps;
    const stepPlane = new THREE.PlaneGeometry(stepPlaneSize, stepPlaneSize);
    const stepEdges = new THREE.EdgesGeometry(stepPlane);
    const corner = -size / 2 + stepPlaneSize / 2;
    for (let i = 0; i < numSteps; i++) {
      for (let j = 0; j < numSteps; j++) {
        const stepFrame = new THREE.LineSegments(stepEdges, this.gridMaterial);
        stepFrame.translateX(corner + i * stepPlaneSize);
        stepFrame.translateY(corner + j * stepPlaneSize);
        grid.add(stepFrame);
      }
    }

    return grid;
  }

  /**
   * Creates the materials and copies the options from a fully populated
   * configuration (defaults already merged in).
   */
  private setConfig(config: LandmarkGridConfig) {
    this.landmarkMaterial =
        new THREE.MeshBasicMaterial({color: config.landmarkColor!});
    this.landmarkGeometry = new THREE.SphereGeometry(config.landmarkSize!);
    this.nonvisibleMaterial =
        new THREE.MeshBasicMaterial({color: config.nonvisibleLandmarkColor!});
    this.axesMaterial = new THREE.LineBasicMaterial(
        {color: config.axesColor!, linewidth: config.axesWidth!});
    this.gridMaterial = new THREE.LineBasicMaterial({color: 0x999999});
    this.connectionMaterial = new THREE.LineBasicMaterial(
        {color: config.connectionColor!, linewidth: config.connectionWidth!});
    this.isVisible = (e: NormalizedLandmark) => (
        (e.visibility !== undefined) && (e.visibility > config.minVisibility!));
    // A Map, so that names such as 'toString' cannot resolve to inherited
    // Object.prototype members.
    this.definedColors = new Map();
    for (const color of config.definedColors!) {
      this.definedColors.set(
          color.name,
          new THREE.LineBasicMaterial(
              {color: color.value, linewidth: config.connectionWidth}));
    }
    this.backgroundColor = config.backgroundColor!;
    this.showHidden = config.showHidden!;
    this.fitToGrid = config.fitToGrid!;
    this.sizeWhenFitted = (1 - 2 * config.margin!);
    this.numCellsPerAxis = config.numCellsPerAxis!;
    this.labelSuffix = config.labelSuffix!;
    this.labelPrefix = config.labelPrefix!;
    this.centered = config.centered!;
    this.range = config.range!;
  }

  /**
   * Lets the user rotate the camera by dragging horizontally on the canvas.
   * A drag across the full canvas width is one full turn. Automatic rotation
   * is suspended during the drag and restored afterwards.
   */
  private setMouseDrag() {
    const el = this.renderer.domElement;
    // Finishes the drag in progress, if any. A new mousedown calls it so that
    // a drag whose mouseup was never delivered is cleaned up, instead of the
    // suspended speed (0) being captured as the one to restore.
    let endActiveDrag: (() => void)|undefined;
    el.onmousedown = (event: MouseEvent) => {
      event.preventDefault();
      if (endActiveDrag) endActiveDrag();

      const speed = this.rotationSpeed;
      const origRotation = this.rotation;
      // Measured per drag so the drag-to-angle mapping follows resizes.
      const elWidth = el.getBoundingClientRect().width;
      this.rotationSpeed = 0;

      const mouseMove = (e: MouseEvent) => {
        e.preventDefault();
        const rotation = 2 * Math.PI * (event.offsetX - e.offsetX) / elWidth;
        const distance =
            Math.hypot(this.camera.position.x, this.camera.position.z);
        this.rotation = origRotation + rotation;
        this.camera.position.x = Math.sin(this.rotation) * distance;
        this.camera.position.z = Math.cos(this.rotation) * distance;
        // Show the new angle even when no landmark updates are arriving.
        this.requestFrame();
      };
      const endDrag = () => {
        el.removeEventListener('mousemove', mouseMove);
        // Must match the target the listener was added to (document).
        document.removeEventListener('mouseup', mouseUp);
        this.rotationSpeed = speed;
        endActiveDrag = undefined;
      };
      const mouseUp = (e: MouseEvent) => {
        e.preventDefault();
        endDrag();
      };

      endActiveDrag = endDrag;
      el.addEventListener('mousemove', mouseMove);
      document.addEventListener('mouseup', mouseUp);
    };
  }

  /** Adds the image button that toggles automatic rotation. */
  private addPausePlay(parent: HTMLElement) {
    const button = document.createElement('img');
    button.classList.add('controls');
    button.src = PAUSE_SRC;

    button.onclick = () => {
      if (this.isRotating) {
        button.src = PLAY_SRC;
        this.isRotating = false;
      } else {
        button.src = PAUSE_SRC;
        this.isRotating = true;
      }
    };

    parent.appendChild(button);
  }

  /**
   * Returns the largest factor by which `landmarks` (default: all current
   * landmarks) can be scaled so every coordinate stays within
   * `sizeWhenFitted * range / 2` in absolute value. Returns 1 for an empty
   * list, and Infinity when every coordinate is 0.
   */
  private getFitToGridFactor(landmarks?: NormalizedLandmark[]) {
    if (!landmarks) {
      landmarks = this.landmarks;
    }
    if (landmarks.length === 0) {
      return 1;
    }

    let factor = Infinity;
    for (const landmark of landmarks) {
      const maxNum = Math.max(
          Math.abs(landmark.x), Math.abs(landmark.y), Math.abs(landmark.z));
      factor = Math.min(factor, (this.range / 2) / maxNum);
    }
    return factor * this.sizeWhenFitted;
  }

  /** Projects a scene point to pixel coordinates on the canvas (z is 0). */
  private getCanvasPosition(position: THREE.Vector3): THREE.Vector3 {
    const size = this.renderer.domElement.getBoundingClientRect();
    const vector = position.clone().project(this.camera);
    vector.x = Math.round((0.5 + vector.x / 2) * size.width);
    vector.y = Math.round((0.5 - vector.y / 2) * size.height);
    vector.z = 0;
    return vector;
  }

  /** Moves every label element to its projected canvas position. */
  private setLabels() {
    const allLabels =
        [...this.labels.x, ...this.labels.y, ...this.labels.z];
    for (const label of allLabels) {
      const position = this.getCanvasPosition(label.position);
      label.element.style.left = `${position.x}px`;
      label.element.style.top = `${position.y}px`;
    }
  }

  /**
   * Detaches the meshes created by the previous update and disposes of their
   * per-update geometries.
   */
  private clearResources() {
    for (const e of this.removeQueue) {
      if (e.parent) e.parent.remove(e);
    }
    this.removeQueue = [];
    for (const e of this.disposeQueue) {
      e.dispose();
    }
    this.disposeQueue = [];
  }

  /**
   * Shifts every current landmark so that the center of the bounding box of
   * `landmarks` becomes the origin. The applied offset is recorded in
   * `this.origin`.
   */
  private centerLandmarks(landmarks: NormalizedLandmark[]) {
    if (landmarks.length === 0) {
      return;
    }

    let maxX = landmarks[0].x, minX = landmarks[0].x, maxY = landmarks[0].y,
        minY = landmarks[0].y, maxZ = landmarks[0].z, minZ = landmarks[0].z;
    for (let i = 1; i < landmarks.length; i++) {
      const landmark = landmarks[i];
      maxX = Math.max(maxX, landmark.x);
      maxY = Math.max(maxY, landmark.y);
      maxZ = Math.max(maxZ, landmark.z);
      minX = Math.min(minX, landmark.x);
      minY = Math.min(minY, landmark.y);
      minZ = Math.min(minZ, landmark.z);
    }
    const centerX = (maxX + minX) / 2;
    const centerY = (maxY + minY) / 2;
    const centerZ = (maxZ + minZ) / 2;
    for (const landmark of this.landmarks) {
      landmark.x -= centerX;
      landmark.y -= centerY;
      landmark.z -= centerZ;
    }

    this.origin.set(centerX, centerY, centerZ);
  }

  /** Returns a shallow copy holding only x, y, z and visibility. */
  private copyLandmark(e: NormalizedLandmark): NormalizedLandmark {
    return {x: e.x, y: e.y, z: e.z, visibility: e.visibility};
  }
}

// Closure Library export: publishes the class under the global name
// 'LandmarkGrid', so script code can reach it by that name even after Closure
// Compiler renaming. NOTE(review): `goog` is not imported here, so it must be
// provided globally by the build.
goog.exportSymbol('LandmarkGrid', LandmarkGrid);
