//----------------------------------------------------------------------------------
// File:        rtx_video_api_dx12_impl.cpp
// SDK Version: 1.0.2
//
// SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: LicenseRef-NvidiaProprietary
//
// NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
// property and proprietary rights in and to this material, related
// documentation and any modifications thereto. Any use, reproduction,
// disclosure or distribution of this material and related documentation
// without an express license agreement from NVIDIA CORPORATION or
// its affiliates is strictly prohibited.
//
//----------------------------------------------------------------------------------

/**
*  This sample application demonstrates use of the RTX Video SDK
*  by providing an API that takes an input surface and an output surface.
*  Inputs must be 8 bit video.
*  Output from VSR is in 8 bit video (h264/hevc supported).
*  Output from THDR is in 10 bit video (hevc/av1 supported).
*  If both are enabled then VSR -> THDR.
*/

#include <nvsdk_ngx_defs.h>
#include <nvsdk_ngx_defs_truehdr.h>
#include <nvsdk_ngx_helpers_truehdr.h>
#include <nvsdk_ngx_defs_vsr.h>
#include <nvsdk_ngx_helpers_vsr.h>

#if defined(NDEBUG)
#pragma comment( lib, "nvsdk_ngx_s.lib" ) // ngx sdk
#else
#pragma comment( lib, "nvsdk_ngx_s_dbg.lib" ) // ngx sdk
#endif

#include <d3d12.h>
#pragma comment( lib, "d3d12" )

#include <new>      // std::nothrow

#include "rtx_video_api.h"
#include "utils.h"

// Runs RTX Video Super Resolution (VSR) and/or TrueHDR on a D3D12 device through NGX.
// Owns the NGX command queue/allocator/list, a CPU-visible completion fence, the NGX feature
// handles and the temporary surfaces used by evaluate(). Everything is released explicitly by
// shutdown(); there is no destructor.
class dx12_api_impl
{
private:
    ID3D12Device*               m_D3D12Device           = nullptr;  // AddRef'd in create(), released in shutdown()
    ID3D12CommandAllocator*     m_commandAllocatorNGX   = nullptr;
    ID3D12CommandQueue*         m_commandQueueNGX       = nullptr;
    ID3D12GraphicsCommandList*  m_commandListNGX        = nullptr;
    UINT                        m_uGPUNodeMask          = 1;        // CreationNodeMask for temporary surfaces
    UINT                        m_uGPUVisibleNodeMask   = 1;        // VisibleNodeMask for temporary surfaces

    ID3D12Fence*                m_CmdQFence             = nullptr;  // tracks work submitted to m_commandQueueNGX
    uint64_t                    m_qwCmdQFenceValue      = 0;        // last value used with m_CmdQFence

    bool                        m_bNGXInitialized       = false;    // set only when create() fully succeeded
    NVSDK_NGX_Parameter*        m_ngxParameters         = nullptr;
    NVSDK_NGX_Handle*           m_TrueHDRFeature        = nullptr;  // non-null when TrueHDR was enabled
    NVSDK_NGX_Handle*           m_VSRFeature            = nullptr;  // non-null when VSR was enabled

    // When the caller's output lacks UAV access, NGX writes to m_pDstTmp and the result is copied.
    bool                        m_bSetupDstTmp          = false;
    ID3D12Resource*             m_pDstTmp               = nullptr;
    UINT                        m_uDstTmpWidth          = 0;
    UINT                        m_uDstTmpHeight         = 0;

    // When both VSR and TrueHDR are enabled, VSR writes to m_pMiddle and TrueHDR reads from it.
    bool                        m_bNeedMiddle           = false;
    ID3D12Resource*             m_pMiddle               = nullptr;
    UINT                        m_uMiddleWidth          = 0;
    UINT                        m_uMiddleHeight         = 0;


public:
    dx12_api_impl() = default;

    // The members are raw COM/NGX pointers released by shutdown(); a copy would release them twice.
    dx12_api_impl(const dx12_api_impl&)            = delete;
    dx12_api_impl& operator=(const dx12_api_impl&) = delete;

    API_BOOL create(ID3D12Device* pD3DDevice, uint32_t uGPUNodeMask, uint32_t uGPUVisibleNodeMask, API_BOOL THDREnable, API_BOOL VSREnable);
    API_BOOL evaluate(ID3D12Resource* pInput, ID3D12Resource* pOutput, ID3D12Fence* pInFence, uint64_t& qwInFenceValue, ID3D12Fence* pOutFence, uint64_t& qwOutFenceValue,
                                                  API_RECT inputRect, API_RECT outputRect, API_VSR_Setting* pVSRSetting, API_THDR_Setting* pTHDRSetting);
    void shutdown();
};


/**
 *  Creates the NGX command queue/allocator/list, initializes NGX and creates the requested
 *  features (TrueHDR and/or VSR) on pD3DDevice.
 *  Returns TRUE on success. On FALSE, evaluate() is rejected; any NGX state created during
 *  this call has already been undone, and the D3D objects are released by shutdown().
 */
API_BOOL dx12_api_impl::create(ID3D12Device* pD3DDevice, uint32_t uGPUNodeMask, uint32_t uGPUVisibleNodeMask, API_BOOL THDREnable, API_BOOL VSREnable)
{
    // Tear down anything left over from an earlier create() so calling create() twice does not
    // leak the device reference, the command objects or the NGX features.
    shutdown();

    // default to false until creation is done
    m_bNGXInitialized = false;

    if (!pD3DDevice) return FALSE;

    HRESULT hr = S_OK;

    m_D3D12Device = pD3DDevice;
    m_D3D12Device->AddRef();

    m_uGPUNodeMask = uGPUNodeMask;
    m_uGPUVisibleNodeMask = uGPUVisibleNodeMask;

    // fence used to wait (on the CPU) for work submitted to the NGX command queue
    hr = m_D3D12Device->CreateFence(m_qwCmdQFenceValue++, D3D12_FENCE_FLAG_NONE, IID_PPV_ARGS(&m_CmdQFence));
    if (FAILED(hr)) return FALSE;

    {
        // Describe and create the NGX command queue.
        D3D12_COMMAND_QUEUE_DESC queueDesc = {};
        queueDesc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE;
        queueDesc.Type  = D3D12_COMMAND_LIST_TYPE_DIRECT;

        hr = m_D3D12Device->CreateCommandQueue(&queueDesc, IID_PPV_ARGS(&m_commandQueueNGX));
        if (FAILED(hr)) return FALSE;
    }

    // create NGX command allocator
    hr = m_D3D12Device->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&m_commandAllocatorNGX));
    if (FAILED(hr)) return FALSE;

    // create NGX command list; it is returned already open for recording on the allocator,
    // ready to receive the feature-creation commands below
    hr = m_D3D12Device->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_DIRECT, m_commandAllocatorNGX, nullptr, IID_PPV_ARGS(&m_commandListNGX));
    if (FAILED(hr)) return FALSE;

    // init NGX SDK
    NVSDK_NGX_Result NGX_Status = NVSDK_NGX_D3D12_Init(APP_ID, APP_PATH, pD3DDevice);
    if (NVSDK_NGX_FAILED(NGX_Status)) return FALSE;

    // NGX is initialized from here on. Undo it on every later failure so a failed create()
    // leaves no NGX features, parameters or initialization behind.
    auto failAfterNGXInit = [this]() -> API_BOOL
    {
        if (m_VSRFeature)
        {
            NVSDK_NGX_D3D12_ReleaseFeature(m_VSRFeature);
            m_VSRFeature = nullptr;
        }
        if (m_TrueHDRFeature)
        {
            NVSDK_NGX_D3D12_ReleaseFeature(m_TrueHDRFeature);
            m_TrueHDRFeature = nullptr;
        }
        if (m_ngxParameters)
        {
            NVSDK_NGX_D3D12_DestroyParameters(m_ngxParameters);
            m_ngxParameters = nullptr;
        }
        NVSDK_NGX_D3D12_Shutdown1(m_D3D12Device);
        return FALSE;
    };

    // Get the NGX capability parameters (destroyed with NVSDK_NGX_D3D12_DestroyParameters on teardown)
    NGX_Status = NVSDK_NGX_D3D12_GetCapabilityParameters(&m_ngxParameters);
    if (NVSDK_NGX_FAILED(NGX_Status)) return failAfterNGXInit();

    if (THDREnable)
    {
        // Check if TrueHDR is available on the system
        int TrueHDRAvailable = 0;
        NGX_Status = m_ngxParameters->Get(NVSDK_NGX_Parameter_TrueHDR_Available, &TrueHDRAvailable);
        if (!TrueHDRAvailable) return failAfterNGXInit();

        // Create the TrueHDR feature instance
        NVSDK_NGX_Feature_Create_Params TrueHDRCreateParams = {};
        NGX_Status = NGX_D3D12_CREATE_TRUEHDR_EXT(m_commandListNGX, uGPUNodeMask, uGPUVisibleNodeMask, &m_TrueHDRFeature, m_ngxParameters, &TrueHDRCreateParams);
        if (NVSDK_NGX_FAILED(NGX_Status)) return failAfterNGXInit();
    }
    if (VSREnable)
    {
        // Check if VSR is available on the system
        int VSRAvailable = 0;
        NGX_Status = m_ngxParameters->Get(NVSDK_NGX_Parameter_VSR_Available, &VSRAvailable);
        if (!VSRAvailable) return failAfterNGXInit();

        // Create the VSR feature instance
        NVSDK_NGX_Feature_Create_Params VSRCreateParams = {};
        NGX_Status = NGX_D3D12_CREATE_VSR_EXT(m_commandListNGX, uGPUNodeMask, uGPUVisibleNodeMask, &m_VSRFeature, m_ngxParameters, &VSRCreateParams);
        if (NVSDK_NGX_FAILED(NGX_Status)) return failAfterNGXInit();
    }

    hr = m_commandListNGX->Close();
    if (FAILED(hr)) return failAfterNGXInit();

    // execute the command list
    {
        ID3D12CommandList* ppCommandLists[] = { m_commandListNGX };
        m_commandQueueNGX->ExecuteCommandLists(_countof(ppCommandLists), ppCommandLists);
    }

    // Block the CPU until the feature-creation commands have finished, so the allocator is idle
    // when evaluate() resets it. ID3D12CommandQueue::Wait would only stall the GPU.
    hr = m_commandQueueNGX->Signal(m_CmdQFence, ++m_qwCmdQFenceValue);
    if (SUCCEEDED(hr) && m_CmdQFence->GetCompletedValue() < m_qwCmdQFenceValue)
    {
        // a null event handle makes SetEventOnCompletion block until the value is reached
        hr = m_CmdQFence->SetEventOnCompletion(m_qwCmdQFenceValue, nullptr);
    }
    if (FAILED(hr)) return failAfterNGXInit();

    m_bNeedMiddle = (THDREnable && VSREnable);
    m_bNGXInitialized = true;
    return TRUE;
}

API_BOOL dx12_api_impl::evaluate(ID3D12Resource* pInput, ID3D12Resource* pOutput, ID3D12Fence* pInFence, uint64_t& qwInFenceValue, ID3D12Fence* pOutFence, uint64_t& qwOutFenceValue,
                                 API_RECT inputRect, API_RECT outputRect, API_VSR_Setting* pVSRSetting, API_THDR_Setting* pTHDRSetting)
{
    if (!m_bNGXInitialized)
    {
        return FALSE;
    }

    if (m_TrueHDRFeature && !pTHDRSetting)
    {
        return FALSE;
    }

    if (m_VSRFeature && !pVSRSetting)
    {
        return FALSE;
    }
    
    HRESULT hr = S_OK;
    NVSDK_NGX_Result NGX_Status;

    // check formats
    {
        D3D12_RESOURCE_DESC inDesc = {};
        D3D12_RESOURCE_DESC outDesc = {};
        // check input is DXGI_FORMAT_R8G8B8A8_UNORM or DXGI_FORMAT_B8G8R8A8_UNORM
        inDesc = pInput->GetDesc();
        if (inDesc.Format != DXGI_FORMAT_R8G8B8A8_UNORM && inDesc.Format != DXGI_FORMAT_B8G8R8A8_UNORM)
        {
            return FALSE;
        }
        // verify input rect is within range
        if (   inputRect.left < 0 || inputRect.left >= inputRect.right || inputRect.right  > inDesc.Width
            || inputRect.top  < 0 || inputRect.top >= inputRect.bottom || inputRect.bottom > inDesc.Height)
        {
            return FALSE;
        }
        // check output is DXGI_FORMAT_R8G8B8A8_UNORM or DXGI_FORMAT_B8G8R8A8_UNORM
        outDesc = pOutput->GetDesc();
        if (m_TrueHDRFeature)
        {
            // check output is HDR format DXGI_FORMAT_R10G10B10A2_UNORM or DXGI_FORMAT_R16G16B16A16_FLOAT
            if (outDesc.Format != DXGI_FORMAT_R10G10B10A2_UNORM && outDesc.Format != DXGI_FORMAT_R16G16B16A16_FLOAT)
            {
                return FALSE;
            }
        }
        else if (outDesc.Format != DXGI_FORMAT_R8G8B8A8_UNORM && outDesc.Format != DXGI_FORMAT_B8G8R8A8_UNORM)
        {
            return FALSE;
        }

        // verify output rect is within range
        if (   outputRect.left < 0 || outputRect.left >= outputRect.right || outputRect.right  > outDesc.Width
            || outputRect.top  < 0 || outputRect.top >= outputRect.bottom || outputRect.bottom > outDesc.Height)
        {
            return FALSE;
        }

        // The NGX dst surface must be created with ALLOW_UNORDERED_ACCESS, which swap buffers are not.
        // check for UNORDERED_ACCESS
        m_bSetupDstTmp = !(outDesc.Flags & D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS);

        // verify DstTmp matches dest surface so copyRegion works
        if (m_bSetupDstTmp && (!m_pDstTmp || outDesc.Width != m_uDstTmpWidth || outDesc.Height != m_uDstTmpHeight))
        {
            SafeRelease(m_pDstTmp);
            m_uDstTmpWidth                      = (uint32_t)outDesc.Width;
            m_uDstTmpHeight                     = (uint32_t)outDesc.Height;

            D3D12_RESOURCE_DESC textureDesc     = {};
            textureDesc.Dimension               = D3D12_RESOURCE_DIMENSION_TEXTURE2D;
            textureDesc.Alignment               = 0;
            textureDesc.Width                   = m_uDstTmpWidth;
            textureDesc.Height                  = m_uDstTmpHeight;
            textureDesc.DepthOrArraySize        = 1;
            textureDesc.MipLevels               = 0;
            textureDesc.Format                  = outDesc.Format;
            textureDesc.SampleDesc.Count        = 1;
            textureDesc.SampleDesc.Quality      = 0;
            textureDesc.Layout                  = D3D12_TEXTURE_LAYOUT_UNKNOWN;
            textureDesc.Flags                   = D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS;

            D3D12_HEAP_PROPERTIES heapProp      = {};
            heapProp.Type                       = D3D12_HEAP_TYPE_DEFAULT;
            heapProp.CPUPageProperty            = D3D12_CPU_PAGE_PROPERTY_UNKNOWN;
            heapProp.MemoryPoolPreference       = D3D12_MEMORY_POOL_UNKNOWN;
            heapProp.CreationNodeMask           = m_uGPUNodeMask;
            heapProp.VisibleNodeMask            = m_uGPUVisibleNodeMask;

            D3D12_HEAP_FLAGS heapFlags          = D3D12_HEAP_FLAG_NONE;
            D3D12_RESOURCE_STATES initState     = D3D12_RESOURCE_STATE_COMMON;
            hr = m_D3D12Device->CreateCommittedResource( &heapProp,
                                                        heapFlags,
                                                        &textureDesc,
                                                        initState,
                                                        nullptr,
                                                        IID_PPV_ARGS(&m_pDstTmp));
            if (FAILED(hr)) return FALSE;
        }

        if (m_bNeedMiddle && (!m_pMiddle || outDesc.Width != m_uMiddleWidth || outDesc.Height != m_uMiddleHeight))
        {
            SafeRelease(m_pMiddle);
            m_uMiddleWidth                          = (uint32_t)outDesc.Width;
            m_uMiddleHeight                         = (uint32_t)outDesc.Height;

            D3D12_RESOURCE_DESC textureDesc     = {};
            textureDesc.Dimension               = D3D12_RESOURCE_DIMENSION_TEXTURE2D;
            textureDesc.Alignment               = 0;
            textureDesc.Width                   = m_uMiddleWidth;
            textureDesc.Height                  = m_uMiddleHeight;
            textureDesc.DepthOrArraySize        = 1;
            textureDesc.MipLevels               = 0;
            textureDesc.Format                  = DXGI_FORMAT_R8G8B8A8_UNORM;
            textureDesc.SampleDesc.Count        = 1;
            textureDesc.SampleDesc.Quality      = 0;
            textureDesc.Layout                  = D3D12_TEXTURE_LAYOUT_UNKNOWN;
            textureDesc.Flags                   = D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS;

            D3D12_HEAP_PROPERTIES heapProp      = {};
            heapProp.Type                       = D3D12_HEAP_TYPE_DEFAULT;
            heapProp.CPUPageProperty            = D3D12_CPU_PAGE_PROPERTY_UNKNOWN;
            heapProp.MemoryPoolPreference       = D3D12_MEMORY_POOL_UNKNOWN;
            heapProp.CreationNodeMask           = m_uGPUNodeMask;
            heapProp.VisibleNodeMask            = m_uGPUVisibleNodeMask;

            D3D12_HEAP_FLAGS heapFlags          = D3D12_HEAP_FLAG_NONE;
            D3D12_RESOURCE_STATES initState     = D3D12_RESOURCE_STATE_COMMON;
            hr = m_D3D12Device->CreateCommittedResource( &heapProp,
                                                        heapFlags,
                                                        &textureDesc,
                                                        initState,
                                                        nullptr,
                                                        IID_PPV_ARGS(&m_pMiddle));
            if (FAILED(hr)) return FALSE;
        }
    }

    // make sure the commandQ and allocator are ready
    m_commandQueueNGX->Signal(m_CmdQFence, ++m_qwCmdQFenceValue);
    m_commandQueueNGX->Wait(m_CmdQFence, m_qwCmdQFenceValue);

    // add a fence in the command queue to wait for the src to be ready
    m_commandQueueNGX->Wait(pInFence, qwInFenceValue);
    // add a fence in the command queue to wait for the dst to be ready
    m_commandQueueNGX->Wait(pOutFence, qwOutFenceValue);


    // clear the allocator and command list for evaluate
    hr = m_commandAllocatorNGX->Reset();
    if (FAILED(hr)) return FALSE;
    hr = m_commandListNGX->Reset(m_commandAllocatorNGX, nullptr);
    if (FAILED(hr)) return FALSE;

    if (m_VSRFeature)
    {
        // setup VSR params
        NVSDK_NGX_D3D12_VSR_Eval_Params D3D12VsrEvalParams = {};
        D3D12VsrEvalParams.pInput                       = pInput;
        D3D12VsrEvalParams.pOutput                      = m_bNeedMiddle ? m_pMiddle : (m_bSetupDstTmp ? m_pDstTmp : pOutput);
        D3D12VsrEvalParams.InputSubrectBase.X           = inputRect.left;
        D3D12VsrEvalParams.InputSubrectBase.Y           = inputRect.top;
        D3D12VsrEvalParams.InputSubrectSize.Width       = inputRect.right - inputRect.left;
        D3D12VsrEvalParams.InputSubrectSize.Height      = inputRect.bottom - inputRect.top;
        D3D12VsrEvalParams.OutputSubrectBase.X          = outputRect.left;
        D3D12VsrEvalParams.OutputSubrectBase.Y          = outputRect.top;
        D3D12VsrEvalParams.OutputSubrectSize.Width      = outputRect.right - outputRect.left;
        D3D12VsrEvalParams.OutputSubrectSize.Height     = outputRect.bottom - outputRect.top;
        D3D12VsrEvalParams.QualityLevel                 = (NVSDK_NGX_VSR_QualityLevel)pVSRSetting->QualityLevel;

        // evaluate VSR
        NGX_Status = NGX_D3D12_EVALUATE_VSR_EXT(m_commandListNGX, m_VSRFeature, m_ngxParameters, &D3D12VsrEvalParams);
        if (NVSDK_NGX_FAILED(NGX_Status)) return FALSE;
    }
    if (m_TrueHDRFeature)
    {
        NVSDK_NGX_D3D12_TRUEHDR_Eval_Params D3D12TrueHDREvalParams = {};

        D3D12TrueHDREvalParams.pInput                   = m_bNeedMiddle ? m_pMiddle : pInput;
        D3D12TrueHDREvalParams.pOutput                  = m_bSetupDstTmp ? m_pDstTmp : pOutput;
        D3D12TrueHDREvalParams.InputSubrectTL.X         = m_bNeedMiddle ? outputRect.left   : inputRect.left;
        D3D12TrueHDREvalParams.InputSubrectTL.Y         = m_bNeedMiddle ? outputRect.top    : inputRect.top;
        D3D12TrueHDREvalParams.InputSubrectBR.Width     = m_bNeedMiddle ? outputRect.right  : inputRect.right;
        D3D12TrueHDREvalParams.InputSubrectBR.Height    = m_bNeedMiddle ? outputRect.bottom : inputRect.bottom;
        D3D12TrueHDREvalParams.OutputSubrectTL.X        = outputRect.left;
        D3D12TrueHDREvalParams.OutputSubrectTL.Y        = outputRect.top;
        D3D12TrueHDREvalParams.OutputSubrectBR.Width    = outputRect.right;
        D3D12TrueHDREvalParams.OutputSubrectBR.Height   = outputRect.bottom;
        D3D12TrueHDREvalParams.Contrast                 = pTHDRSetting->Contrast;
        D3D12TrueHDREvalParams.Saturation               = pTHDRSetting->Saturation;
        D3D12TrueHDREvalParams.MiddleGray               = pTHDRSetting->MiddleGray;
        D3D12TrueHDREvalParams.MaxLuminance             = pTHDRSetting->MaxLuminance;
        NGX_Status = NGX_D3D12_EVALUATE_TRUEHDR_EXT(m_commandListNGX, m_TrueHDRFeature, m_ngxParameters, &D3D12TrueHDREvalParams);
        if (NVSDK_NGX_FAILED(NGX_Status)) return FALSE;
    }

    // check if need to transfer to actual dst
    if (m_bSetupDstTmp)
    {
        D3D12_TEXTURE_COPY_LOCATION Dst = {};
        D3D12_TEXTURE_COPY_LOCATION Src = {};
        Dst.pResource = pOutput;
        Src.pResource = m_pDstTmp;
        m_commandListNGX->CopyTextureRegion(&Dst, 0, 0, 0, &Src, nullptr);
    }

    // close the command list
    hr = m_commandListNGX->Close();
    if (FAILED(hr)) return FALSE;
    // execute the command list
    {
        ID3D12CommandList* ppCommandListsNGX[] = { m_commandListNGX };
        m_commandQueueNGX->ExecuteCommandLists(_countof(ppCommandListsNGX), ppCommandListsNGX);

    }
    // signal completion of using src, dst
    m_commandQueueNGX->Signal(pInFence, ++qwInFenceValue);
    m_commandQueueNGX->Signal(pOutFence, ++qwOutFenceValue);

    return TRUE;
}

void dx12_api_impl::shutdown()
{
    // make sure the commandQueue is done
    if (m_CmdQFence && m_commandQueueNGX)
    {
        m_commandQueueNGX->Signal(m_CmdQFence, ++m_qwCmdQFenceValue);
        m_commandQueueNGX->Wait(m_CmdQFence, m_qwCmdQFenceValue);
    }
    if (m_bNGXInitialized)
    {
        if (m_VSRFeature)
        {
            NVSDK_NGX_D3D12_ReleaseFeature(m_VSRFeature);
            m_VSRFeature = nullptr;
        }
        if (m_TrueHDRFeature)
        {
            NVSDK_NGX_D3D12_ReleaseFeature(m_TrueHDRFeature);
            m_TrueHDRFeature = nullptr;
        }
        NVSDK_NGX_D3D12_Shutdown1(m_D3D12Device);
        NVSDK_NGX_D3D12_DestroyParameters(m_ngxParameters);

        m_bNGXInitialized = false;
    }
    SafeRelease(m_pDstTmp);
    SafeRelease(m_pMiddle);
    SafeRelease(m_CmdQFence);
    SafeRelease(m_commandQueueNGX);
    SafeRelease(m_commandListNGX);
    SafeRelease(m_commandAllocatorNGX);
    SafeRelease(m_D3D12Device);
}


////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////

// Process-wide instance driven by the exported entry points below (no locking is done here).
dx12_api_impl* p_dx12_api_impl = nullptr;

// Allocates the instance on first use and (re)creates it on pD3DDevice.
// Returns false if the instance cannot be allocated or create() fails.
#if !defined(_WIN32)
__attribute__ ((visibility("default")))
#endif
API_BOOL rtx_video_api_dx12_create(ID3D12Device* pD3DDevice, uint32_t uGPUNodeMask, uint32_t uGPUVisibleNodeMask, API_BOOL THDREnable, API_BOOL VSREnable)
{
    if (!p_dx12_api_impl)
    {
        // nothrow: an allocation failure must be reported, not escape this exported entry point
        p_dx12_api_impl = new (std::nothrow) dx12_api_impl;
    }
    if (!p_dx12_api_impl) return false;
    return p_dx12_api_impl->create(pD3DDevice, uGPUNodeMask, uGPUVisibleNodeMask, THDREnable, VSREnable);
}

#if !defined(_WIN32)
__attribute__((visibility("default")))
#endif
// Exported entry point: forwards to the process-wide instance.
// Returns false when rtx_video_api_dx12_create() has not allocated an instance yet.
API_BOOL rtx_video_api_dx12_evaluate(ID3D12Resource* pInput, ID3D12Resource* pOutput, ID3D12Fence* pInFence, uint64_t& qwInFenceValue, ID3D12Fence* pOutFence, uint64_t& qwOutFenceValue,
                                                  API_RECT inputRect, API_RECT outputRect, API_VSR_Setting* pVSRSetting, API_THDR_Setting* pTHDRSetting)
{
    dx12_api_impl* const pImpl = p_dx12_api_impl;
    if (pImpl == nullptr)
    {
        return false;
    }
    return pImpl->evaluate(pInput, pOutput,
                           pInFence, qwInFenceValue,
                           pOutFence, qwOutFenceValue,
                           inputRect, outputRect,
                           pVSRSetting, pTHDRSetting);
}

#if !defined(_WIN32)
__attribute__((visibility("default")))
#endif
// Exported entry point: shuts down and frees the process-wide instance, if one exists.
void rtx_video_api_dx12_shutdown()
{
    if (!p_dx12_api_impl)
    {
        return;  // nothing was created, nothing to tear down
    }

    p_dx12_api_impl->shutdown();
    delete p_dx12_api_impl;
    p_dx12_api_impl = nullptr;
}
