//----------------------------------------------------------------------------------
// File:        CNGXDX12VSR.cpp
// SDK Version: 1.0.2
//
// SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: LicenseRef-NvidiaProprietary
//
// NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
// property and proprietary rights in and to this material, related
// documentation and any modifications thereto. Any use, reproduction,
// disclosure or distribution of this material and related documentation
// without an express license agreement from NVIDIA CORPORATION or
// its affiliates is strictly prohibited.
//
//----------------------------------------------------------------------------------

///////////////////////////////////////////
// CNGXDX12VSR.cpp
// This is a wrapper class to simplify adding NGX VSR to a DX12 app
// see .h for info

#include <d3d12.h>

#define NGX_CLASS_USE
#include "CDx12NGXVSR.h"

// add sync class
#include "CDx12Sync.h"

// Initializes NGX on the given device and creates the VSR feature together with the
// private direct command queue, allocator and command list used to run it.
//   pD3DDevice          - device to create the feature on; AddRef'd here and released by ReleaseFeature()
//   uGPUNodeMask        - node mask passed to the NGX feature creation (also used for the temp dst heap)
//   uGPUVisibleNodeMask - visible node mask passed to the NGX feature creation (also used for the temp dst heap)
// Returns S_OK on success, E_INVALIDARG for a null device, E_FAIL when NGX cannot be
// initialized / VSR is not available / the feature cannot be created, or the HRESULT of
// the failing D3D12 call.
// The D3D12 objects created here are released by ReleaseFeature(), also after a failure.
// NOTE(review): the queue and command list are created with node mask 0 regardless of
// uGPUNodeMask - confirm this is intended for multi-adapter setups.
HRESULT CDx12NGXVSR::CreateFeature(ID3D12Device* pD3DDevice, UINT uGPUNodeMask, UINT uGPUVisibleNodeMask)
{
    // the device is dereferenced unconditionally below
    if (!pD3DDevice) return E_INVALIDARG;

    HRESULT hr = S_OK;
    // default to false until creation is done
    m_bNGXInitialized   = false;

    m_D3D12Device = pD3DDevice;
    m_D3D12Device->AddRef();

    m_uGPUNodeMask = uGPUNodeMask;
    m_uGPUVisibleNodeMask = uGPUVisibleNodeMask;
    // used to verify commandQ and allocator are ready
    m_CmdQSyncObjNGX = new CDx12SyncObject(m_D3D12Device);

    // init NGX SDK
    NVSDK_NGX_Result Status = NVSDK_NGX_D3D12_Init(APP_ID, APP_PATH, m_D3D12Device);
    if (NVSDK_NGX_FAILED(Status)) return E_FAIL;

    // From here on NGX is initialized. ReleaseFeature() only shuts NGX down when
    // m_bNGXInitialized is set, so every failure path below has to undo the NGX state itself.
    bool bHaveParameters = false;
    bool bHaveFeature    = false;
    auto FailAfterInit = [&](HRESULT hrFailure) -> HRESULT
    {
        if (bHaveFeature)
        {
            NVSDK_NGX_D3D12_ReleaseFeature(m_VSRFeature);
            m_VSRFeature = nullptr;
        }
        if (bHaveParameters)
        {
            NVSDK_NGX_D3D12_DestroyParameters(m_ngxParameters);
            m_ngxParameters = nullptr;
        }
        NVSDK_NGX_D3D12_Shutdown1(m_D3D12Device);
        return hrFailure;
    };

    // Get NGX capability parameters (destroyed with NVSDK_NGX_D3D12_DestroyParameters on release)
    Status = NVSDK_NGX_D3D12_GetCapabilityParameters(&m_ngxParameters);
    if (NVSDK_NGX_FAILED(Status)) return FailAfterInit(E_FAIL);
    bHaveParameters = true;

    // Now check if VSR is available on the system
    int VSRAvailable = 0;
    NVSDK_NGX_Result ResultVSR = m_ngxParameters->Get(NVSDK_NGX_Parameter_VSR_Available, &VSRAvailable);
    if (NVSDK_NGX_FAILED(ResultVSR) || !VSRAvailable) return FailAfterInit(E_FAIL);

    {
        // Describe and create the NGX command queue.
        D3D12_COMMAND_QUEUE_DESC queueDesc = {};
        queueDesc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE;
        queueDesc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;

        hr = m_D3D12Device->CreateCommandQueue(&queueDesc, IID_PPV_ARGS(&m_commandQueueNGX));
        if (FAILED(hr)) return FailAfterInit(hr);
    }

    // create NGX command allocator
    hr = m_D3D12Device->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&m_commandAllocatorNGX));
    if (FAILED(hr)) return FailAfterInit(hr);

    // create NGX command list (command lists are created in the recording state, so close it)
    hr = m_D3D12Device->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_DIRECT, m_commandAllocatorNGX, nullptr, IID_PPV_ARGS(&m_commandListNGX));
    if (FAILED(hr)) return FailAfterInit(hr);

    hr = m_commandListNGX->Close();
    if (FAILED(hr)) return FailAfterInit(hr);

    // make sure the commandQ and allocator are ready
    m_CmdQSyncObjNGX->SignalFence(m_commandQueueNGX);
    m_CmdQSyncObjNGX->WaitForCPUFence();

    // clear allocator and command list for the creation
    hr = m_commandAllocatorNGX->Reset();
    if (FAILED(hr)) return FailAfterInit(hr);
    hr = m_commandListNGX->Reset(m_commandAllocatorNGX, nullptr);
    if (FAILED(hr)) return FailAfterInit(hr);

    // Create the VSR feature instance
    NVSDK_NGX_Feature_Create_Params VSRCreateParams = {};
    ResultVSR = NGX_D3D12_CREATE_VSR_EXT(m_commandListNGX, uGPUNodeMask, uGPUVisibleNodeMask, &m_VSRFeature, m_ngxParameters, &VSRCreateParams);
    bHaveFeature = !NVSDK_NGX_FAILED(ResultVSR);

    // always close the list, even when the creation failed, so it is never left recording
    hr = m_commandListNGX->Close();
    if (FAILED(hr)) return FailAfterInit(hr);

    if (!bHaveFeature)
    {
        return FailAfterInit(E_FAIL);
    }

    // execute the command list; completion is awaited by the fence wait at the start of
    // the next EvaluateFeature() / ReleaseFeature() call
    {
        ID3D12CommandList* ppCommandLists[] = { m_commandListNGX };
        m_commandQueueNGX->ExecuteCommandLists(_countof(ppCommandLists), ppCommandLists);

        m_bNGXInitialized = true;
    }
    return S_OK;
}

// Applies VSR from a sub-rectangle of Input to a sub-rectangle of Output.
//   Output / Input            - textures; both must be DXGI_FORMAT_R8G8B8A8_UNORM or DXGI_FORMAT_B8G8R8A8_UNORM
//   pDstSyncObj / pSrcSyncObj - sync objects of the two textures; the NGX queue waits on both
//                               before the work and signals both after submitting it
//   RectOutput / RectInput    - non-empty rects that must lie inside the respective texture
//   Quality                   - cast to NVSDK_NGX_VSR_QualityLevel without range checking
// Returns S_OK on success, E_FAIL when the feature has not been created or the NGX evaluate
// fails, E_INVALIDARG for null arguments / unsupported formats / out-of-range rects, or the
// HRESULT of the failing D3D12 call.
// When Output was not created with ALLOW_UNORDERED_ACCESS the result is rendered into an
// internal temp texture of the same size and format and then copied to Output.
HRESULT CDx12NGXVSR::EvaluateFeature(ID3D12Resource* Output, CDx12SyncObject* pDstSyncObj, RECT RectOutput,
                                       ID3D12Resource* Input, CDx12SyncObject* pSrcSyncObj, RECT RectInput,
                                       int Quality)
{
    if (!m_bNGXInitialized)
    {
        return E_FAIL;
    }
    // all four pointers are dereferenced unconditionally below
    if (!Output || !Input || !pDstSyncObj || !pSrcSyncObj)
    {
        return E_INVALIDARG;
    }
    HRESULT hr = S_OK;
    D3D12_RESOURCE_DESC inDesc;
    D3D12_RESOURCE_DESC outDesc;
    // check formats
    {
        // check input is DXGI_FORMAT_R8G8B8A8_UNORM or DXGI_FORMAT_B8G8R8A8_UNORM
        inDesc = Input->GetDesc();
        if (inDesc.Format != DXGI_FORMAT_R8G8B8A8_UNORM && inDesc.Format != DXGI_FORMAT_B8G8R8A8_UNORM)
        {
            return E_INVALIDARG;
        }
        // verify input rect is non-empty and within range
        if (   RectInput.left < 0 || RectInput.left >= RectInput.right || RectInput.right  > static_cast<LONG>(inDesc.Width)
            || RectInput.top  < 0 || RectInput.top >= RectInput.bottom || RectInput.bottom > static_cast<LONG>(inDesc.Height))
        {
            return E_INVALIDARG;
        }
        // check output is DXGI_FORMAT_R8G8B8A8_UNORM or DXGI_FORMAT_B8G8R8A8_UNORM
        outDesc = Output->GetDesc();
        if (outDesc.Format != DXGI_FORMAT_R8G8B8A8_UNORM && outDesc.Format != DXGI_FORMAT_B8G8R8A8_UNORM)
        {
            return E_INVALIDARG;
        }
        // verify output rect is non-empty and within range
        if (   RectOutput.left < 0 || RectOutput.left >= RectOutput.right || RectOutput.right  > static_cast<LONG>(outDesc.Width)
            || RectOutput.top  < 0 || RectOutput.top >= RectOutput.bottom || RectOutput.bottom > static_cast<LONG>(outDesc.Height))
        {
            return E_INVALIDARG;
        }

        // The NGX dst surface must be created with ALLOW_UNORDERED_ACCESS, which swap buffers are not.
        // m_bNGXInitializedDstTmp is true when the temp dst surface has to be used for this call.
        m_bNGXInitializedDstTmp = !(outDesc.Flags & D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS);

        // verify DstTmp matches the dest surface (size AND format) so the copy below works
        bool bRecreateDstTmp = false;
        if (m_bNGXInitializedDstTmp)
        {
            bRecreateDstTmp = !m_pDstTmpNGX
                           || outDesc.Width  != m_uDstTmpWidth
                           || outDesc.Height != m_uDstTmpHeight
                           || m_pDstTmpNGX->GetDesc().Format != outDesc.Format;
        }
        if (bRecreateDstTmp)
        {
            if (m_pDstTmpNGX)
            {
                // a previous evaluate may still be using the old temp surface on the GPU,
                // so drain the NGX queue before dropping the last reference to it
                m_CmdQSyncObjNGX->SignalFence(m_commandQueueNGX);
                m_CmdQSyncObjNGX->WaitForCPUFence();
                SafeRelease(m_pDstTmpNGX);
            }
            m_uDstTmpWidth                      = outDesc.Width;
            m_uDstTmpHeight                     = outDesc.Height;

            D3D12_RESOURCE_DESC textureDesc     = {};
            textureDesc.Dimension               = D3D12_RESOURCE_DIMENSION_TEXTURE2D;
            textureDesc.Alignment               = 0;
            textureDesc.Width                   = static_cast<UINT>(m_uDstTmpWidth);
            textureDesc.Height                  = static_cast<UINT>(m_uDstTmpHeight);
            textureDesc.DepthOrArraySize        = 1;
            // only subresource 0 is ever written and copied; 0 would allocate a full mip chain
            textureDesc.MipLevels               = 1;
            textureDesc.Format                  = outDesc.Format;
            textureDesc.SampleDesc.Count        = 1;
            textureDesc.SampleDesc.Quality      = 0;
            textureDesc.Layout                  = D3D12_TEXTURE_LAYOUT_UNKNOWN;
            textureDesc.Flags                   = D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS;

            D3D12_HEAP_PROPERTIES heapProp      = {};
            heapProp.Type                       = D3D12_HEAP_TYPE_DEFAULT;
            heapProp.CPUPageProperty            = D3D12_CPU_PAGE_PROPERTY_UNKNOWN;
            heapProp.MemoryPoolPreference       = D3D12_MEMORY_POOL_UNKNOWN;
            heapProp.CreationNodeMask           = m_uGPUNodeMask;
            heapProp.VisibleNodeMask            = m_uGPUVisibleNodeMask;

            D3D12_HEAP_FLAGS heapFlags          = D3D12_HEAP_FLAG_NONE;
            D3D12_RESOURCE_STATES initState     = D3D12_RESOURCE_STATE_COMMON;
            hr = m_D3D12Device->CreateCommittedResource( &heapProp,
                                                        heapFlags,
                                                        &textureDesc,
                                                        initState,
                                                        nullptr,
                                                        IID_PPV_ARGS(&m_pDstTmpNGX));
            if (FAILED(hr)) return hr;
        }
    }

    // NGX output rect, typically 0,0,swapWidth,swapHeight
    m_NGXDstRect = RectOutput;
    // NGX input rect,  typically 0,0,decodeWidth,decodeHeight
    m_NGXSrcRect = RectInput;

    // make sure the commandQ and allocator are ready (previous submission has finished)
    m_CmdQSyncObjNGX->SignalFence(m_commandQueueNGX);
    m_CmdQSyncObjNGX->WaitForCPUFence();

    // add a fence in the command queue to wait for the src to be ready
    pSrcSyncObj->WaitForFence(m_commandQueueNGX);
    // add a fence in the command queue to wait for the dst to be ready
    pDstSyncObj->WaitForFence(m_commandQueueNGX);

    // clear the allocator and command list for evaluate
    hr = m_commandAllocatorNGX->Reset();
    if (FAILED(hr)) return hr;
    hr = m_commandListNGX->Reset(m_commandAllocatorNGX, nullptr);
    if (FAILED(hr)) return hr;

    // setup VSR params
    NVSDK_NGX_D3D12_VSR_Eval_Params D3D12VsrEvalParams = {};
    D3D12VsrEvalParams.pInput                   = Input;
    D3D12VsrEvalParams.pOutput                  = m_bNGXInitializedDstTmp ? m_pDstTmpNGX : Output;
    D3D12VsrEvalParams.InputSubrectBase.X       = RectInput.left;
    D3D12VsrEvalParams.InputSubrectBase.Y       = RectInput.top;
    D3D12VsrEvalParams.InputSubrectSize.Width   = RectInput.right - RectInput.left;
    D3D12VsrEvalParams.InputSubrectSize.Height  = RectInput.bottom - RectInput.top;
    D3D12VsrEvalParams.OutputSubrectBase.X      = RectOutput.left;
    D3D12VsrEvalParams.OutputSubrectBase.Y      = RectOutput.top;
    D3D12VsrEvalParams.OutputSubrectSize.Width  = RectOutput.right - RectOutput.left;
    D3D12VsrEvalParams.OutputSubrectSize.Height = RectOutput.bottom - RectOutput.top;
    D3D12VsrEvalParams.QualityLevel             = static_cast<NVSDK_NGX_VSR_QualityLevel>(Quality);

    // evaluate VSR
    NVSDK_NGX_Result ResultVSR = NGX_D3D12_EVALUATE_VSR_EXT(m_commandListNGX, m_VSRFeature, m_ngxParameters, &D3D12VsrEvalParams);
    if (NVSDK_NGX_FAILED(ResultVSR))
    {
        // leave the list closed; an open list would make the Reset() calls of every
        // following evaluate fail
        m_commandListNGX->Close();
        return E_FAIL;
    }

    // check if need to transfer to actual dst
    if (m_bNGXInitializedDstTmp)
    {
        // whole-texture copy of subresource 0 from the temp surface to the real destination
        D3D12_TEXTURE_COPY_LOCATION Dst = {};
        D3D12_TEXTURE_COPY_LOCATION Src = {};
        Dst.pResource        = Output;
        Dst.Type             = D3D12_TEXTURE_COPY_TYPE_SUBRESOURCE_INDEX;
        Dst.SubresourceIndex = 0;
        Src.pResource        = m_pDstTmpNGX;
        Src.Type             = D3D12_TEXTURE_COPY_TYPE_SUBRESOURCE_INDEX;
        Src.SubresourceIndex = 0;
        m_commandListNGX->CopyTextureRegion(&Dst, 0, 0, 0, &Src, nullptr);
    }

    // close the command list
    hr = m_commandListNGX->Close();
    if (FAILED(hr)) return hr;
    // execute the command list
    {
        ID3D12CommandList* ppCommandListsNGX[] = { m_commandListNGX };
        m_commandQueueNGX->ExecuteCommandLists(_countof(ppCommandListsNGX), ppCommandListsNGX);
    }
    // signal completion of using src, dst
    pSrcSyncObj->SignalFence(m_commandQueueNGX);
    pDstSyncObj->SignalFence(m_commandQueueNGX);

    return hr;
}

// Releases the VSR feature, shuts NGX down and releases every D3D12 object held by this class.
// ReleaseFeature needs to be called before the class destructor
// as the runtime destructor will delete a critical section used in NVSDK_NGX_D3D12_Shutdown1
void CDx12NGXVSR::ReleaseFeature()
{
    // make sure the commandQueue is done before anything it may still use is destroyed
    if (m_CmdQSyncObjNGX && m_commandQueueNGX)
    {
        m_CmdQSyncObjNGX->SignalFence(m_commandQueueNGX);
        m_CmdQSyncObjNGX->WaitForCPUFence();
    }
    if (m_bNGXInitialized)
    {
        NVSDK_NGX_D3D12_ReleaseFeature(m_VSRFeature);
        m_VSRFeature = nullptr;
        // the capability parameter map has to be destroyed while the SDK is still
        // initialized, i.e. before the shutdown call
        NVSDK_NGX_D3D12_DestroyParameters(m_ngxParameters);
        m_ngxParameters = nullptr;
        NVSDK_NGX_D3D12_Shutdown1(m_D3D12Device);

        m_bNGXInitialized = false;
    }
    SafeDelete(m_CmdQSyncObjNGX);
    SafeRelease(m_pDstTmpNGX);
    SafeRelease(m_commandQueueNGX);
    SafeRelease(m_commandListNGX);
    SafeRelease(m_commandAllocatorNGX);
    // releases the reference taken in CreateFeature()
    SafeRelease(m_D3D12Device);
}


// example of a DX12 Sync Object
#if 0

// Example fence wrapper: one ID3D12Fence plus the bookkeeping needed to signal it from a
// command queue and to wait for it on either a command queue or the CPU.
class CDx12SyncObject
{
public:
    CDx12SyncObject(ID3D12Device* pD3D12Device);
    ~CDx12SyncObject();

    // queues a signal of the next fence value on pCmdQ and returns that value
    UINT64 SignalFence(ID3D12CommandQueue* pCmdQ);
    // makes pCmdQ wait for the most recently signaled value
    void WaitForFence(ID3D12CommandQueue* pCmdQ);
    // makes the calling thread wait for the most recently signaled value
    void WaitForCPUFence();

    // these are for checking if an older value has finished
    void WaitForFenceValue(ID3D12CommandQueue* pCmdQ, UINT64 waitValue);
    void WaitForCPUFenceValue(UINT64 waitValue);

private:
    // Synchronization objects.
    ID3D12Fence* m_fence              = nullptr;
    UINT64       m_fenceValue         = 0;        // next value to signal
    UINT64       m_fenceSignaledValue = 0;        // value of the most recent signal
    HANDLE       m_fenceEvent         = nullptr;  // created on first CPU wait
};


//////////////////////////////////////////////////////////////////////////////////////
// Creates the fence with the current counter as its initial value and advances the counter,
// so the first SignalFence() uses the following value.
// NOTE(review): the CreateFence result is not checked; m_fence stays nullptr on failure.
CDx12SyncObject::CDx12SyncObject(ID3D12Device* pD3D12Device)
{
    const UINT64 initialValue = m_fenceValue;
    ++m_fenceValue;
    pD3D12Device->CreateFence(initialValue, D3D12_FENCE_FLAG_NONE, IID_PPV_ARGS(&m_fence));
}

//////////////////////////////////////////////////////////////////////////////////////
// Drops the fence reference and closes the event handle used for CPU waits.
CDx12SyncObject::~CDx12SyncObject()
{
    // fence first, then the event that may have been created by WaitForCPUFence()
    SafeRelease(m_fence);
    SafeCloseHandle(m_fenceEvent);
}

//////////////////////////////////////////////////////////////////////////////////////
// Queues a signal of the next fence value on pCmdQ, remembers it as the most recently
// signaled value and returns it.
UINT64 CDx12SyncObject::SignalFence(ID3D12CommandQueue* pCmdQ)
{
    // take the current counter as the value to signal, then advance the counter
    const UINT64 valueToSignal = m_fenceValue;
    ++m_fenceValue;

    m_fenceSignaledValue = valueToSignal;
    pCmdQ->Signal(m_fence, valueToSignal);
    return valueToSignal;
}


//////////////////////////////////////////////////////////////////////////////////////
// Makes pCmdQ wait (on the GPU timeline) for the most recently signaled fence value.
// No wait is queued when that value has already been reached.
void CDx12SyncObject::WaitForFence(ID3D12CommandQueue* pCmdQ)
{
    const bool alreadyReached = !(m_fence->GetCompletedValue() < m_fenceSignaledValue);
    if (alreadyReached)
    {
        return;
    }
    pCmdQ->Wait(m_fence, m_fenceSignaledValue);
}


//////////////////////////////////////////////////////////////////////////////////////
void CDx12SyncObject::WaitForCPUFence()
{
    if (m_fence->GetCompletedValue() < m_fenceSignaledValue)
    {
        if (!m_fenceEvent)
            m_fenceEvent = CreateEvent(nullptr, FALSE, FALSE, nullptr);
        if (m_fenceEvent)
        {
            m_fence->SetEventOnCompletion(m_fenceSignaledValue, m_fenceEvent);
            WaitForSingleObject(m_fenceEvent, INFINITE);
        }
    }
}

#endif
