//----------------------------------------------------------------------------------
// File:        CDx11NGXVSR.cpp
// SDK Version: 1.0.2
//
// SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: LicenseRef-NvidiaProprietary
//
// NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
// property and proprietary rights in and to this material, related
// documentation and any modifications thereto. Any use, reproduction,
// disclosure or distribution of this material and related documentation
// without an express license agreement from NVIDIA CORPORATION or
// its affiliates is strictly prohibited.
//
//----------------------------------------------------------------------------------

 ///////////////////////////////////////////
// CDx11NGXVSR.cpp
// This is a wrapper class to simplify adding NGX VSR to a DX11 app
// see .h for info

#include <d3d11_4.h>

#define NGX_CLASS_USE
#include "CDx11NGXVSR.h"

// Initialize the NGX SDK on pD3DDevice and create the VSR feature instance.
//
// On the way it takes a reference on the device and its immediate context (both released
// in ReleaseFeature) and, when the context exposes ID3D10Multithread, turns on
// multithread protection and holds that lock while the feature is created.
//
// Returns S_OK on success, E_INVALIDARG if pD3DDevice is null, and E_FAIL if NGX
// initialization fails, VSR is not available, or feature creation fails. On any failure
// after NVSDK_NGX_D3D11_Init succeeded, the parameter block is destroyed and NGX is shut
// down again here, because ReleaseFeature only does that when m_bNGXInitialized is true.
HRESULT CDx11NGXVSR::CreateFeature(ID3D11Device* pD3DDevice)
{
    HRESULT hr = S_OK;
    // default to false until creation is done
    m_bNGXInitialized   = false;

    if (!pD3DDevice) return E_INVALIDARG;

    // init NGX SDK
    NVSDK_NGX_Result Status = NVSDK_NGX_D3D11_Init(APP_ID, APP_PATH, pD3DDevice);
    if (NVSDK_NGX_FAILED(Status)) return E_FAIL;

    // Undo the NGX initialization on a failure path below, so a failed CreateFeature
    // does not leave the NGX runtime initialized or the parameter block allocated.
    auto abortInit = [&]() -> HRESULT
    {
        if (m_ngxParameters)
        {
            NVSDK_NGX_D3D11_DestroyParameters(m_ngxParameters);
            m_ngxParameters = nullptr;
        }
        NVSDK_NGX_D3D11_Shutdown1(pD3DDevice);
        return E_FAIL;
    };

    // Get the NGX capability parameter block. It is allocated by NGX but owned by this
    // class: it must be freed with NVSDK_NGX_D3D11_DestroyParameters (see ReleaseFeature).
    m_ngxParameters = nullptr;
    Status = NVSDK_NGX_D3D11_GetCapabilityParameters(&m_ngxParameters);
    if (NVSDK_NGX_FAILED(Status)) return abortInit();

    // Now check if VSR is available on the system; a failed query counts as "not available"
    int VSRAvailable = 0;
    NVSDK_NGX_Result ResultVSR = m_ngxParameters->Get(NVSDK_NGX_Parameter_VSR_Available, &VSRAvailable);
    if (NVSDK_NGX_FAILED(ResultVSR) || !VSRAvailable) return abortInit();

    m_pD3D11Device = pD3DDevice;
    m_pD3D11Device->AddRef();

    m_pD3D11Device->GetImmediateContext(&m_pD3D11DeviceContext);

    // Serialize access to the immediate context while NGX uses it
    hr = m_pD3D11DeviceContext->QueryInterface(__uuidof(ID3D10Multithread), (void**)&m_pMultiThread);
    if (SUCCEEDED(hr))
    {
        m_pMultiThread->SetMultithreadProtected(TRUE);
        m_pMultiThread->Enter();
    }

    // Create the VSR feature instance
    NVSDK_NGX_Feature_Create_Params VSRCreateParams = {};
    ResultVSR = NGX_D3D11_CREATE_VSR_EXT(m_pD3D11DeviceContext, &m_VSRFeature, m_ngxParameters, &VSRCreateParams);

    // Always drop the context lock, whether or not creation succeeded; returning while
    // still inside Enter() would block every other thread that uses the context.
    if (m_pMultiThread)
    {
        m_pMultiThread->Leave();
    }

    if (NVSDK_NGX_FAILED(ResultVSR))
    {
        return abortInit();
    }

    m_bNGXInitialized = true;
    return S_OK;
}

// Apply VSR: upscale/enhance RectInput of Input into RectOutput of Output.
//
// Input and Output must both be DXGI_FORMAT_R8G8B8A8_UNORM or DXGI_FORMAT_B8G8R8A8_UNORM,
// and each rect must lie inside its texture with left < right and top < bottom.
// If Output was not created with D3D11_BIND_UNORDERED_ACCESS (e.g. a swap-chain buffer),
// NGX renders into an internal temporary texture and only RectOutput is then copied into
// Output, so texels of Output outside RectOutput are left untouched.
// Quality is cast to NVSDK_NGX_VSR_QualityLevel without range checking.
//
// Returns S_OK on success; E_FAIL if CreateFeature has not succeeded or NGX evaluation
// fails; E_INVALIDARG for a null texture, unsupported format, or out-of-range rect;
// or the HRESULT from CreateTexture2D if the temporary texture cannot be allocated.
HRESULT CDx11NGXVSR::EvaluateFeature(ID3D11Texture2D* Output, RECT RectOutput,
                                       ID3D11Texture2D* Input, RECT RectInput,
                                       int Quality)
{
    if (!m_bNGXInitialized)
    {
        return E_FAIL;
    }
    if (!Input || !Output)
    {
        return E_INVALIDARG;
    }

    HRESULT hr = S_OK;
    D3D11_TEXTURE2D_DESC inDesc = {};
    D3D11_TEXTURE2D_DESC outDesc = {};
    // check formats and rects
    {
        // check input is DXGI_FORMAT_R8G8B8A8_UNORM or DXGI_FORMAT_B8G8R8A8_UNORM
        Input->GetDesc(&inDesc);
        if (inDesc.Format != DXGI_FORMAT_R8G8B8A8_UNORM && inDesc.Format != DXGI_FORMAT_B8G8R8A8_UNORM)
        {
            return E_INVALIDARG;
        }
        // verify input rect is within range
        if (   RectInput.left < 0 || RectInput.left >= RectInput.right || RectInput.right  >(LONG)inDesc.Width
            || RectInput.top  < 0 || RectInput.top >= RectInput.bottom || RectInput.bottom >(LONG)inDesc.Height)
        {
            return E_INVALIDARG;
        }
        // check output is DXGI_FORMAT_R8G8B8A8_UNORM or DXGI_FORMAT_B8G8R8A8_UNORM
        Output->GetDesc(&outDesc);
        if (outDesc.Format != DXGI_FORMAT_R8G8B8A8_UNORM && outDesc.Format != DXGI_FORMAT_B8G8R8A8_UNORM)
        {
            return E_INVALIDARG;
        }
        // verify output rect is within range
        if (   RectOutput.left < 0 || RectOutput.left >= RectOutput.right || RectOutput.right  >(LONG)outDesc.Width
            || RectOutput.top  < 0 || RectOutput.top >= RectOutput.bottom || RectOutput.bottom >(LONG)outDesc.Height)
        {
            return E_INVALIDARG;
        }

        // The NGX dst surface must be created with BIND_UNORDERED_ACCESS, which swap buffers are not.
        // When it is missing, render into a temporary texture and copy into Output afterwards.
        m_bNGXInitializedDstTmp = !(outDesc.BindFlags & D3D11_BIND_UNORDERED_ACCESS);

        if (m_bNGXInitializedDstTmp)
        {
            // (Re)create the temporary when it is missing or no longer matches the destination
            // in size or format. CopySubresourceRegion needs copy-compatible formats, and
            // R8G8B8A8 and B8G8R8A8 belong to different typeless groups.
            bool recreateTmp = !m_pDstTmpNGX;
            if (!recreateTmp)
            {
                D3D11_TEXTURE2D_DESC tmpDesc = {};
                m_pDstTmpNGX->GetDesc(&tmpDesc);
                recreateTmp =  tmpDesc.Width  != outDesc.Width
                            || tmpDesc.Height != outDesc.Height
                            || tmpDesc.Format != outDesc.Format;
            }

            if (recreateTmp)
            {
                SafeRelease(m_pDstTmpNGX);
                m_uDstTmpWidth                          = outDesc.Width;
                m_uDstTmpHeight                         = outDesc.Height;
                D3D11_TEXTURE2D_DESC texture2d_desc     = { 0 };
                texture2d_desc.Width                    = m_uDstTmpWidth;
                texture2d_desc.Height                   = m_uDstTmpHeight;
                texture2d_desc.MipLevels                = 1;
                texture2d_desc.ArraySize                = 1;
                texture2d_desc.SampleDesc.Count         = 1;
                texture2d_desc.MiscFlags                = 0;
                texture2d_desc.Format                   = outDesc.Format;
                texture2d_desc.Usage                    = D3D11_USAGE_DEFAULT;
                texture2d_desc.BindFlags                = D3D11_BIND_RENDER_TARGET | D3D11_BIND_SHADER_RESOURCE | D3D11_BIND_UNORDERED_ACCESS;

                hr = m_pD3D11Device->CreateTexture2D(&texture2d_desc, NULL, &m_pDstTmpNGX);
                if (FAILED(hr)) return hr;
            }
        }
    }

    // NGX output rect, typically 0,0,swapWidth,swapHeight
    m_NGXDstRect = RectOutput;
    // NGX input rect,  typically 0,0,decodeWidth,decodeHeight
    m_NGXSrcRect = RectInput;

    // setup VSR params
    NVSDK_NGX_D3D11_VSR_Eval_Params D3D11VsrEvalParams = {};
    D3D11VsrEvalParams.pInput                   = Input;
    D3D11VsrEvalParams.pOutput                  = m_bNGXInitializedDstTmp ? m_pDstTmpNGX : Output;
    D3D11VsrEvalParams.InputSubrectBase.X       = RectInput.left;
    D3D11VsrEvalParams.InputSubrectBase.Y       = RectInput.top;
    D3D11VsrEvalParams.InputSubrectSize.Width   = RectInput.right - RectInput.left;
    D3D11VsrEvalParams.InputSubrectSize.Height  = RectInput.bottom - RectInput.top;
    D3D11VsrEvalParams.OutputSubrectBase.X      = RectOutput.left;
    D3D11VsrEvalParams.OutputSubrectBase.Y      = RectOutput.top;
    D3D11VsrEvalParams.OutputSubrectSize.Width  = RectOutput.right - RectOutput.left;
    D3D11VsrEvalParams.OutputSubrectSize.Height = RectOutput.bottom - RectOutput.top;
    D3D11VsrEvalParams.QualityLevel             = (NVSDK_NGX_VSR_QualityLevel) Quality;

    if (m_pMultiThread)
    {
        m_pMultiThread->Enter();
    }

    NVSDK_NGX_Result ResultVSR = NGX_D3D11_EVALUATE_VSR_EXT(m_pD3D11DeviceContext, m_VSRFeature, m_ngxParameters, &D3D11VsrEvalParams);
    const bool evalFailed = NVSDK_NGX_FAILED(ResultVSR);

    if (!evalFailed && m_bNGXInitializedDstTmp)
    {
        // Copy back only the region NGX wrote, placed at the same position in Output,
        // so Output texels outside RectOutput are not overwritten with temp contents.
        D3D11_BOX srcBox = {};
        srcBox.left   = static_cast<UINT>(RectOutput.left);
        srcBox.top    = static_cast<UINT>(RectOutput.top);
        srcBox.front  = 0;
        srcBox.right  = static_cast<UINT>(RectOutput.right);
        srcBox.bottom = static_cast<UINT>(RectOutput.bottom);
        srcBox.back   = 1;
        m_pD3D11DeviceContext->CopySubresourceRegion(Output, 0,
                                                     static_cast<UINT>(RectOutput.left),
                                                     static_cast<UINT>(RectOutput.top), 0,
                                                     m_pDstTmpNGX, 0, &srcBox);
    }

    // Release the context lock on every path, including evaluation failure.
    if (m_pMultiThread)
    {
        m_pMultiThread->Leave();
    }

    return evalFailed ? E_FAIL : hr;
}


// Release the VSR feature, the NGX parameter block and the NGX SDK, then drop the
// D3D references taken in CreateFeature and EvaluateFeature. Safe to call more than once.
//
// ReleaseFeature needs to be called before the class destructor
// as the runtime destructor will delete a critical section used in NVSDK_NGX_D3D11_Shutdown1
void CDx11NGXVSR::ReleaseFeature()
{
    if (m_bNGXInitialized)
    {
        NVSDK_NGX_D3D11_ReleaseFeature(m_VSRFeature);
        m_VSRFeature = nullptr;
        // Destroy the parameter block while NGX is still initialized; shutting NGX down
        // first would leave DestroyParameters operating on a torn-down runtime.
        NVSDK_NGX_D3D11_DestroyParameters(m_ngxParameters);
        m_ngxParameters = nullptr;
        NVSDK_NGX_D3D11_Shutdown1(m_pD3D11Device);
        m_bNGXInitialized = false;
    }
    SafeRelease(m_pDstTmpNGX);
    SafeRelease(m_pMultiThread);
    SafeRelease(m_pD3D11DeviceContext);
    SafeRelease(m_pD3D11Device);
}

