import Detector from './Detector';
import * as tf from '@tensorflow/tfjs';
import cv from '../third_party/opencv.js';
import { Pose, numHeatMapsChannel, numPafsChannels } from './Pose';
import { getSetting } from '../utils/Settings';

// Width/height passed to the base Detector; also the resolution the model outputs are resized to for decoding.
const modelWidth = 192;
const modelHeight = 144;
// Raw model output maps are 1/8 of the model resolution.
// NOTE(review): assumes the model's output stride is 8 — confirm against the model definition.
const outputWidth = modelWidth >> 3;
const outputHeight = modelHeight >> 3;

// For avatars.
const avatarSize = 128; // Side length, in pixels, of the square avatar image.
const headBoxMargin = 0.6; // Fraction of the head box's larger dimension added on EACH side when cropping the avatar.

class PoseDetector extends Detector {
    /**
     * Wraps an already-loaded pose model and allocates the reusable OpenCV buffers
     * used to decode its outputs. Prefer the async {@link PoseDetector.create} factory,
     * which also loads the model and waits for OpenCV.js to be ready.
     *
     * @param {tf.LayersModel} model - Pose model; its predict() output is read as [heatMaps, pafs].
     * @param {boolean} mirrored - Forwarded to the base Detector; also mirrors the full-frame copy used for avatars.
     */
    constructor(model, mirrored) {
        super(modelWidth, modelHeight, mirrored);
        this._model = model;

        // CV_32F type code with `channels` channels: each channel beyond the first adds
        // (CV_32FC2 - CV_32FC1) to the single-channel type code.
        const float32Type = (channels) => cv.CV_32FC1 + (cv.CV_32FC2 - cv.CV_32FC1) * (channels - 1);

        // Raw model outputs at output resolution...
        this._reusableHeatMapsMat = new cv.Mat(outputHeight, outputWidth, float32Type(numHeatMapsChannel));
        this._reusablePafsMat = new cv.Mat(outputHeight, outputWidth, float32Type(numPafsChannels));
        // ...and the same maps resized to the model resolution, which is what gets decoded.
        this._reusableResizedHeatMapsMat = new cv.Mat(modelHeight, modelWidth, float32Type(numHeatMapsChannel));
        this._reusableResizedPafsMat = new cv.Mat(modelHeight, modelWidth, float32Type(numPafsChannels));
        this._processingSize = new cv.Size(modelWidth, modelHeight);

        // NOTE(matthew_chan): Ideally, these should all be OffscreenCanvas.
        this._fullFrame = document.createElement('canvas');
        this._avatarFrame = document.createElement('canvas');
        this._avatarFrame.width = avatarSize;
        this._avatarFrame.height = avatarSize;
    }

    /**
     * Runs pose detection on the given canvas.
     * The debug canvas is used to render debug information.
     *
     * @param {HTMLCanvasElement} canvas - Source frame.
     * @param {Boolean} detectExtraInfo - When true, also tries to crop an avatar image around the first pose's head box.
     * @param {HTMLCanvasElement} [debugCanvas] - Forwarded to Pose.decodeFromMLResults.
     * @returns {Promise<{detection: Array, extraInfo: ?{avatar: string}}>} Decoded poses, plus a JPEG data URL
     *     avatar when one could be extracted (otherwise extraInfo is null).
     */
    async predict(canvas, detectExtraInfo, debugCanvas) {
        const context = this._frameCanvasContext;
        const scaleX = canvas.width / this._imageWidth;
        const scaleY = canvas.height / this._imageHeight;
        // Scale the incoming frame onto the base Detector's frame canvas, which is what the model reads.
        context.drawImage(canvas, 0, 0, canvas.width, canvas.height, 0, 0, this._imageWidth, this._imageHeight);

        if (detectExtraInfo) {
            // Keep the original-resolution pixels so the head image can be cropped from them later.
            this._captureFullFrame(canvas);
        }

        // Map pixel values from [0, 255] to [-1, 1] and add a batch dimension.
        const input = tf.tidy(() => {
            const offset = tf.scalar(127.5);
            const imageData = tf.browser.fromPixels(this._frameCanvas);
            return imageData.toFloat().sub(offset).div(offset).expandDims();
        });

        let prediction = null;
        let poses;
        try {
            if (getSetting('perfTuning.disablePrediction')) {
                return { detection: [], extraInfo: null };
            }
            prediction = this._model.predict(input);

            if (getSetting('perfTuning.disablePose')) {
                return { detection: [], extraInfo: null };
            }

            await Promise.all([
                this._populateMat(prediction[0], this._reusableHeatMapsMat, this._reusableResizedHeatMapsMat),
                this._populateMat(prediction[1], this._reusablePafsMat, this._reusableResizedPafsMat),
            ]);

            poses = Pose.decodeFromMLResults(
                this._reusableResizedHeatMapsMat,
                this._reusableResizedPafsMat,
                scaleX,
                scaleY,
                debugCanvas
            );
        } finally {
            // Tensors are not garbage collected: release them on every path,
            // including the perf-tuning early returns and any thrown error.
            tf.dispose(input);
            if (prediction !== null) {
                tf.dispose(prediction);
            }
        }

        const extraInfo = detectExtraInfo && poses.length > 0 ? this._extractAvatar(poses[0].headBox) : null;
        return { detection: poses, extraInfo };
    }

    /**
     * Copies `canvas` at its own resolution into the full-frame canvas, mirrored
     * horizontally when this detector is mirrored.
     * @param {HTMLCanvasElement} canvas
     */
    _captureFullFrame(canvas) {
        // Assigning width/height clears the canvas and resets its transform, so the mirror is never applied twice.
        this._fullFrame.width = canvas.width;
        this._fullFrame.height = canvas.height;
        const fullFrameContext = this._fullFrame.getContext('2d');
        if (this._mirrored) {
            fullFrameContext.setTransform(-1, 0, 0, 1, canvas.width, 0);
        }
        fullFrameContext.drawImage(canvas, 0, 0);
    }

    /**
     * Crops a square region centred on `headBox` out of the full-frame canvas and
     * renders it into the avatar canvas.
     * @param {{x: number, y: number, w: number, h: number}} headBox - Head box in full-frame pixel coordinates.
     * @returns {?{avatar: string}} JPEG data URL of the avatar, or null when the head box is
     *     too small or too close to the frame border.
     */
    _extractAvatar(headBox) {
        if (headBox.w < 5 || headBox.h < 5) {
            // If the width or height are too small, it is probably an ill detection.
            return null;
        }
        // If the headBox is too close to the frame boundary, we will bail out.
        const frameWidth = this._fullFrame.width;
        const frameHeight = this._fullFrame.height;
        const safeMargin = Math.min(frameWidth, frameHeight) * 0.1;
        if (
            headBox.x < safeMargin ||
            headBox.y < safeMargin ||
            headBox.x + headBox.w >= frameWidth - safeMargin ||
            headBox.y + headBox.h >= frameHeight - safeMargin
        ) {
            return null;
        }

        const avatarContext = this._avatarFrame.getContext('2d');
        avatarContext.fillStyle = 'gray';
        avatarContext.fillRect(0, 0, avatarSize, avatarSize); // Gray background for any area the crop does not cover.

        // Square crop: the head box's larger side plus headBoxMargin of it on each side.
        const headCenterX = headBox.x + headBox.w * 0.5;
        const headCenterY = headBox.y + headBox.h * 0.5;
        const halfOrigSize = Math.max(headBox.w, headBox.h) * (1 + 2 * headBoxMargin) * 0.5;

        // Clip the crop to the full frame...
        const sourceMinX = Math.max(0, Math.round(headCenterX - halfOrigSize));
        const sourceMaxX = Math.min(frameWidth, Math.round(headCenterX + halfOrigSize));
        const sourceMinY = Math.max(0, Math.round(headCenterY - halfOrigSize));
        const sourceMaxY = Math.min(frameHeight, Math.round(headCenterY + halfOrigSize));
        // ...and map the clipped source range to the matching sub-rectangle of the avatar.
        const destMinX = Math.floor(avatarSize * 0.5 * (1 + (sourceMinX - headCenterX) / halfOrigSize));
        const destMaxX = Math.floor(avatarSize * 0.5 * (1 + (sourceMaxX - headCenterX) / halfOrigSize));
        const destMinY = Math.floor(avatarSize * 0.5 * (1 + (sourceMinY - headCenterY) / halfOrigSize));
        const destMaxY = Math.floor(avatarSize * 0.5 * (1 + (sourceMaxY - headCenterY) / halfOrigSize));

        avatarContext.drawImage(
            this._fullFrame,
            sourceMinX,
            sourceMinY,
            sourceMaxX - sourceMinX,
            sourceMaxY - sourceMinY,
            destMinX,
            destMinY,
            destMaxX - destMinX,
            destMaxY - destMinY
        );

        return { avatar: this._avatarFrame.toDataURL('image/jpeg') };
    }

    /**
     * Extract the content of the tensor into a cv.Mat, then resize it (bicubic) to the
     * model resolution.
     * NOTE(review): assumes the tensor's flat data layout (batch 1, height, width, channels)
     * matches the Mat's interleaved layout — confirm against the model output shape.
     * @param {tf.Tensor4D} tensor
     * @param {cv.Mat} outputMat - Receives the raw values at output resolution.
     * @param {cv.Mat} resizedMat - Receives the resized maps.
     */
    async _populateMat(tensor, outputMat, resizedMat) {
        const data = await tensor.data();
        outputMat.data32F.set(data);
        cv.resize(outputMat, resizedMat, this._processingSize, 0, 0, cv.INTER_CUBIC);
    }

    /**
     * Releases the base detector, the model, and the OpenCV buffers.
     * OpenCV.js Mats are allocated on the WASM heap and are not garbage collected,
     * so they must be deleted explicitly.
     */
    dispose() {
        super.dispose();
        this._model.dispose();
        for (const mat of [
            this._reusableHeatMapsMat,
            this._reusablePafsMat,
            this._reusableResizedHeatMapsMat,
            this._reusableResizedPafsMat,
        ]) {
            mat.delete();
        }
    }

    /**
     * Loads the pose model (the small variant when `activity.useSmallModel` is set),
     * waits for OpenCV.js, and builds a mirrored detector.
     * @returns {Promise<PoseDetector>}
     */
    static async create() {
        const modelName = getSetting('activity.useSmallModel') ? 'nex_pose_bc' : 'nex_pose_b';
        const model = await tf.loadLayersModel(`${process.env.PUBLIC_URL}/assets/${modelName}/model.json`);
        await waitForCV();
        return new PoseDetector(model, true);
    }
}

/**
 * Resolves once OpenCV.js is ready, i.e. once `cv.Mat` is defined.
 * The check runs immediately and then every `pollIntervalMs` milliseconds.
 * Polling stops as soon as the check succeeds, so no timers are left running.
 *
 * @param {number} [pollIntervalMs=50] - Delay between readiness checks, in milliseconds.
 * @returns {Promise<void>}
 */
async function waitForCV(pollIntervalMs = 50) {
    while (!cv.Mat) {
        await new Promise((resolve) => setTimeout(resolve, pollIntervalMs));
    }
}

// Instances can be built with the async PoseDetector.create() factory, which loads the model and waits for OpenCV.js.
export default PoseDetector;
