//----------------------------------------------------------------------------------
// File:        CNGXDX12TrueHDR.cpp
// SDK Version: 1.0.2
//
// SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: LicenseRef-NvidiaProprietary
//
// NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
// property and proprietary rights in and to this material, related
// documentation and any modifications thereto. Any use, reproduction,
// disclosure or distribution of this material and related documentation
// without an express license agreement from NVIDIA CORPORATION or
// its affiliates is strictly prohibited.
//
//----------------------------------------------------------------------------------

///////////////////////////////////////////
// CNGXDX12TrueHDR.cpp
// This is a wrapper class to simplify adding NGX TrueHDR to a DX12 app
// see .h for info

#include <d3d12.h>

///////////////////////////////////
#define NGX_CLASS_USE
#include "CDx12NGXTrueHDR.h"

// add sync class
#include "CDx12Sync.h"

// Initializes the NGX SDK on the given device and creates the TrueHDR feature.
//
// pD3DDevice          - device NGX is initialized against; a reference is taken and kept in
//                       m_D3D12Device.
// uGPUNodeMask        - node mask handed to NGX feature creation (also stored for later use).
// uGPUVisibleNodeMask - visible node mask handed to NGX feature creation (also stored).
//
// Returns S_OK on success (m_bNGXInitialized is then true), E_INVALIDARG for a null device,
// E_FAIL when an NGX call fails or TrueHDR is reported unavailable, or the HRESULT of the
// failing D3D12 call. If a step fails after NGX was initialized, NGX is torn down again
// before returning, because ReleaseFeature() only shuts NGX down when m_bNGXInitialized is set.
HRESULT CDx12NGXTrueHDR::CreateFeature(ID3D12Device* pD3DDevice, UINT uGPUNodeMask, UINT uGPUVisibleNodeMask)
{
    HRESULT hr = S_OK;
    // default to false until creation is done
    m_bNGXInitialized   = false;

    // a null device would be dereferenced immediately below
    if (!pD3DDevice) return E_INVALIDARG;

    m_D3D12Device = pD3DDevice;
    m_D3D12Device->AddRef();

    m_uGPUNodeMask = uGPUNodeMask;
    m_uGPUVisibleNodeMask = uGPUVisibleNodeMask;
    // used to verify commandQ and allocator are ready
    m_CmdQSyncObjNGX = new CDx12SyncObject(m_D3D12Device);

    // init NGX SDK
    NVSDK_NGX_Result Status = NVSDK_NGX_D3D12_Init(APP_ID, APP_PATH, m_D3D12Device);
    if (NVSDK_NGX_FAILED(Status)) return E_FAIL;

    // From here on NGX is live. Any failure must undo the NGX state itself, otherwise the
    // SDK instance, the parameter map and possibly the feature would never be released.
    bool bHaveParameters = false;
    bool bFeatureCreated = false;
    auto FailAndShutdownNGX = [&](HRESULT hrFailure) -> HRESULT
    {
        if (bFeatureCreated)
        {
            NVSDK_NGX_D3D12_ReleaseFeature(m_TrueHDRFeature);
            m_TrueHDRFeature = nullptr;
        }
        if (bHaveParameters)
        {
            NVSDK_NGX_D3D12_DestroyParameters(m_ngxParameters);
            m_ngxParameters = nullptr;
        }
        NVSDK_NGX_D3D12_Shutdown1(m_D3D12Device);
        return hrFailure;
    };

    // Get the NGX capability parameters interface. This class destroys it explicitly with
    // NVSDK_NGX_D3D12_DestroyParameters (see ReleaseFeature).
    Status = NVSDK_NGX_D3D12_GetCapabilityParameters(&m_ngxParameters);
    if (NVSDK_NGX_FAILED(Status)) return FailAndShutdownNGX(E_FAIL);
    bHaveParameters = true;

    // Now check if TrueHDR is available on the system
    int TrueHDRAvailable = 0;
    NVSDK_NGX_Result ResultTrueHDR = m_ngxParameters->Get(NVSDK_NGX_Parameter_TrueHDR_Available, &TrueHDRAvailable);
    if (NVSDK_NGX_FAILED(ResultTrueHDR) || !TrueHDRAvailable) return FailAndShutdownNGX(E_FAIL);

    {
        // Describe and create the NGX command queue.
        // NOTE(review): the queue and command list are created with node mask 0 while NGX is
        // given uGPUNodeMask - confirm this is intended for multi-node adapters.
        D3D12_COMMAND_QUEUE_DESC queueDesc = {};
        queueDesc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE;
        queueDesc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;

        hr = m_D3D12Device->CreateCommandQueue(&queueDesc, IID_PPV_ARGS(&m_commandQueueNGX));
        if (FAILED(hr)) return FailAndShutdownNGX(hr);
    }

    // create NGX command allocator
    hr = m_D3D12Device->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&m_commandAllocatorNGX));
    if (FAILED(hr)) return FailAndShutdownNGX(hr);

    // create NGX command list (created in the recording state, so close it right away)
    hr = m_D3D12Device->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_DIRECT, m_commandAllocatorNGX, nullptr, IID_PPV_ARGS(&m_commandListNGX));
    if (FAILED(hr)) return FailAndShutdownNGX(hr);

    hr = m_commandListNGX->Close();
    if (FAILED(hr)) return FailAndShutdownNGX(hr);

    // make sure the commandQ and allocator are ready
    m_CmdQSyncObjNGX->SignalFence(m_commandQueueNGX);
    m_CmdQSyncObjNGX->WaitForCPUFence();

    // clear allocator and command list for the creation
    hr = m_commandAllocatorNGX->Reset();
    if (FAILED(hr)) return FailAndShutdownNGX(hr);
    hr = m_commandListNGX->Reset(m_commandAllocatorNGX, nullptr);
    if (FAILED(hr)) return FailAndShutdownNGX(hr);

    // Create the TrueHDR feature instance
    NVSDK_NGX_Feature_Create_Params TrueHDRCreateParams = {};
    ResultTrueHDR = NGX_D3D12_CREATE_TRUEHDR_EXT(m_commandListNGX, uGPUNodeMask, uGPUVisibleNodeMask, &m_TrueHDRFeature, m_ngxParameters, &TrueHDRCreateParams);
    bFeatureCreated = !NVSDK_NGX_FAILED(ResultTrueHDR);

    // always close the list, even when creation failed, so it is never left recording
    hr = m_commandListNGX->Close();
    if (FAILED(hr)) return FailAndShutdownNGX(hr);

    if (!bFeatureCreated)
    {
        return FailAndShutdownNGX(E_FAIL);
    }

    // execute the command list; completion is awaited by the fence wait at the start of
    // EvaluateFeature / ReleaseFeature
    {
        ID3D12CommandList* ppCommandLists[] = { m_commandListNGX };
        m_commandQueueNGX->ExecuteCommandLists(_countof(ppCommandLists), ppCommandLists);

        m_bNGXInitialized = true;
    }
    return S_OK;
}

// Applies TrueHDR from a rectangle of Input to a rectangle of Output on the NGX command queue.
//
// Output / RectOutput - destination resource and rectangle; format must be
//                       DXGI_FORMAT_R10G10B10A2_UNORM or DXGI_FORMAT_R16G16B16A16_FLOAT.
//                       If Output lacks ALLOW_UNORDERED_ACCESS (e.g. a swap chain buffer), NGX
//                       writes to an internal temporary surface that is then copied to Output.
// Input / RectInput   - source resource and rectangle; format must be
//                       DXGI_FORMAT_R8G8B8A8_UNORM or DXGI_FORMAT_B8G8R8A8_UNORM.
// pDstSyncObj / pSrcSyncObj - the NGX queue waits on these before running and signals them
//                       after the work has been submitted.
//
// Returns E_FAIL if CreateFeature has not succeeded or NGX evaluation fails, E_INVALIDARG for
// null pointers, unsupported formats or rectangles outside the resource, otherwise the HRESULT
// of the failing D3D12 call or S_OK.
HRESULT CDx12NGXTrueHDR::EvaluateFeature(ID3D12Resource* Output, CDx12SyncObject* pDstSyncObj, RECT RectOutput,
                                            ID3D12Resource* Input,  CDx12SyncObject* pSrcSyncObj, RECT RectInput,
                                            UINT Contrast,      // 0 to 200 for HDR Contrast   (default 100)
                                            UINT Saturation,    // 0 to 200 for HDR Saturation (default 100)
                                            UINT MiddleGray,    // 10 to 100 for HDR MiddleGray (default  50)
                                            UINT MaxLuminance)  // 400 to 2000 for Monitor MaxLuminance (default 1000)
{
    if (!m_bNGXInitialized)
    {
        return E_FAIL;
    }
    // all four pointers are dereferenced unconditionally below
    if (!Output || !Input || !pDstSyncObj || !pSrcSyncObj)
    {
        return E_INVALIDARG;
    }
    HRESULT hr = S_OK;
    // check formats
    {
        // check input is DXGI_FORMAT_R8G8B8A8_UNORM or DXGI_FORMAT_B8G8R8A8_UNORM
        D3D12_RESOURCE_DESC desc = Input->GetDesc();
        if (desc.Format != DXGI_FORMAT_R8G8B8A8_UNORM && desc.Format != DXGI_FORMAT_B8G8R8A8_UNORM)
        {
            return E_INVALIDARG;
        }
        // verify input rect is within range
        if (   RectInput.left < 0 || RectInput.left >= RectInput.right  || RectInput.right  > (LONG)desc.Width
            || RectInput.top  < 0 || RectInput.top  >= RectInput.bottom || RectInput.bottom > (LONG)desc.Height)
        {
            return E_INVALIDARG;
        }
        // check output is HDR format DXGI_FORMAT_R10G10B10A2_UNORM or DXGI_FORMAT_R16G16B16A16_FLOAT
        desc = Output->GetDesc();
        if (desc.Format != DXGI_FORMAT_R10G10B10A2_UNORM && desc.Format != DXGI_FORMAT_R16G16B16A16_FLOAT)
        {
            return E_INVALIDARG;
        }
        // verify output rect is within range
        if (   RectOutput.left < 0 || RectOutput.left >= RectOutput.right  || RectOutput.right  > (LONG)desc.Width
            || RectOutput.top  < 0 || RectOutput.top  >= RectOutput.bottom || RectOutput.bottom > (LONG)desc.Height)
        {
            return E_INVALIDARG;
        }

        // The NGX dst surface must be created with ALLOW_UNORDERED_ACCESS, which swap buffers are not.
        // check for UNORDERED_ACCESS
        m_bNGXInitializedDstTmp = !(desc.Flags & D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS);

        // verify DstTmp matches dest surface (size AND format) so the copy to Output works
        if (m_bNGXInitializedDstTmp &&
            (   !m_pDstTmpNGX
             || desc.Width  != m_uDstTmpWidth
             || desc.Height != m_uDstTmpHeight
             || m_pDstTmpNGX->GetDesc().Format != desc.Format))
        {
            if (m_pDstTmpNGX)
            {
                // a previous evaluate may still be reading/writing the old surface on the GPU;
                // drain the NGX queue before releasing it
                m_CmdQSyncObjNGX->SignalFence(m_commandQueueNGX);
                m_CmdQSyncObjNGX->WaitForCPUFence();
            }
            SafeRelease(m_pDstTmpNGX);
            m_uDstTmpWidth                      = desc.Width;
            m_uDstTmpHeight                     = desc.Height;

            D3D12_RESOURCE_DESC textureDesc     = {};
            textureDesc.Dimension               = D3D12_RESOURCE_DIMENSION_TEXTURE2D;
            textureDesc.Alignment               = 0;
            textureDesc.Width                   = desc.Width;
            textureDesc.Height                  = desc.Height;
            textureDesc.DepthOrArraySize        = 1;
            // only mip 0 is ever written or copied; 0 would allocate a full mip chain
            textureDesc.MipLevels               = 1;
            textureDesc.Format                  = desc.Format;
            textureDesc.SampleDesc.Count        = 1;
            textureDesc.SampleDesc.Quality      = 0;
            textureDesc.Layout                  = D3D12_TEXTURE_LAYOUT_UNKNOWN;
            textureDesc.Flags                   = D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS;

            D3D12_HEAP_PROPERTIES heapProp      = {};
            heapProp.Type                       = D3D12_HEAP_TYPE_DEFAULT;
            heapProp.CPUPageProperty            = D3D12_CPU_PAGE_PROPERTY_UNKNOWN;
            heapProp.MemoryPoolPreference       = D3D12_MEMORY_POOL_UNKNOWN;
            heapProp.CreationNodeMask           = m_uGPUNodeMask;
            heapProp.VisibleNodeMask            = m_uGPUVisibleNodeMask;

            D3D12_HEAP_FLAGS heapFlags          = D3D12_HEAP_FLAG_NONE;
            D3D12_RESOURCE_STATES initState     = D3D12_RESOURCE_STATE_COMMON;
            hr = m_D3D12Device->CreateCommittedResource( &heapProp,
                                                        heapFlags,
                                                        &textureDesc,
                                                        initState,
                                                        nullptr,
                                                        IID_PPV_ARGS(&m_pDstTmpNGX));
            if (FAILED(hr)) return hr;
        }
    }

    // NGX output rect, typically 0,0,swapWidth,swapHeight
    m_NGXDstRect = RectOutput;
    // NGX input rect,  typically 0,0,decodeWidth,decodeHeight
    m_NGXSrcRect = RectInput;

    // make sure the commandQ and allocator are ready (previous submission has finished)
    m_CmdQSyncObjNGX->SignalFence(m_commandQueueNGX);
    m_CmdQSyncObjNGX->WaitForCPUFence();

    // add a fence in the command queue to wait for the src to be ready
    pSrcSyncObj->WaitForFence(m_commandQueueNGX);
    // add a fence in the command queue to wait for the dst to be ready
    pDstSyncObj->WaitForFence(m_commandQueueNGX);


    // clear the allocator and command list for evaluate
    hr = m_commandAllocatorNGX->Reset();
    if (FAILED(hr)) return hr;
    hr = m_commandListNGX->Reset(m_commandAllocatorNGX, nullptr);
    if (FAILED(hr)) return hr;

    // setup TrueHDR params
    NVSDK_NGX_D3D12_TRUEHDR_Eval_Params D3D12TrueHDREvalParams = {};

    D3D12TrueHDREvalParams.pInput                   = Input;
    D3D12TrueHDREvalParams.pOutput                  = m_bNGXInitializedDstTmp ? m_pDstTmpNGX : Output;
    D3D12TrueHDREvalParams.InputSubrectTL.X         = m_NGXSrcRect.left;
    D3D12TrueHDREvalParams.InputSubrectTL.Y         = m_NGXSrcRect.top;
    D3D12TrueHDREvalParams.InputSubrectBR.Width     = m_NGXSrcRect.right;
    D3D12TrueHDREvalParams.InputSubrectBR.Height    = m_NGXSrcRect.bottom;
    D3D12TrueHDREvalParams.OutputSubrectTL.X        = m_NGXDstRect.left;
    D3D12TrueHDREvalParams.OutputSubrectTL.Y        = m_NGXDstRect.top;
    D3D12TrueHDREvalParams.OutputSubrectBR.Width    = m_NGXDstRect.right;
    D3D12TrueHDREvalParams.OutputSubrectBR.Height   = m_NGXDstRect.bottom;
    D3D12TrueHDREvalParams.Contrast                 = Contrast;
    D3D12TrueHDREvalParams.Saturation               = Saturation;
    D3D12TrueHDREvalParams.MiddleGray               = MiddleGray;
    D3D12TrueHDREvalParams.MaxLuminance             = MaxLuminance;


    // evaluate TrueHDR
    NVSDK_NGX_Result ResultTrueHDR = NGX_D3D12_EVALUATE_TRUEHDR_EXT(m_commandListNGX, m_TrueHDRFeature, m_ngxParameters, &D3D12TrueHDREvalParams);
    if (NVSDK_NGX_FAILED(ResultTrueHDR))
    {
        // Leave the list closed: a list left in the recording state makes the Reset calls of
        // every later EvaluateFeature fail.
        m_commandListNGX->Close();
        return E_FAIL;
    }

    // check if need to transfer to actual dst
    if (m_bNGXInitializedDstTmp)
    {
        D3D12_TEXTURE_COPY_LOCATION Dst = {};
        D3D12_TEXTURE_COPY_LOCATION Src = {};
        Dst.pResource        = Output;
        Dst.Type             = D3D12_TEXTURE_COPY_TYPE_SUBRESOURCE_INDEX;
        Dst.SubresourceIndex = 0;
        Src.pResource        = m_pDstTmpNGX;
        Src.Type             = D3D12_TEXTURE_COPY_TYPE_SUBRESOURCE_INDEX;
        Src.SubresourceIndex = 0;

        // Copy only the rectangle NGX was asked to write. The rest of the temporary surface is
        // never written, so copying it would overwrite Output outside RectOutput with stale data.
        // (The rect was validated above, so the casts cannot see negative values.)
        D3D12_BOX srcBox = {};
        srcBox.left   = static_cast<UINT>(m_NGXDstRect.left);
        srcBox.top    = static_cast<UINT>(m_NGXDstRect.top);
        srcBox.front  = 0;
        srcBox.right  = static_cast<UINT>(m_NGXDstRect.right);
        srcBox.bottom = static_cast<UINT>(m_NGXDstRect.bottom);
        srcBox.back   = 1;
        m_commandListNGX->CopyTextureRegion(&Dst, srcBox.left, srcBox.top, 0, &Src, &srcBox);
    }

    // close the command list
    hr = m_commandListNGX->Close();
    if (FAILED(hr)) return hr;
    // execute the command list
    {
        ID3D12CommandList* ppCommandListsNGX[] = { m_commandListNGX };
        m_commandQueueNGX->ExecuteCommandLists(_countof(ppCommandListsNGX), ppCommandListsNGX);

    }
    // signal completion of using src, dst
    pSrcSyncObj->SignalFence(m_commandQueueNGX);
    pDstSyncObj->SignalFence(m_commandQueueNGX);

    return hr;
}

// Releases the TrueHDR feature, shuts NGX down and releases all D3D12 objects owned by this class.
// ReleaseFeature needs to be called before the class destructor
// as the runtime destructor will delete a critical section used in NVSDK_NGX_D3D12_Shutdown1
void CDx12NGXTrueHDR::ReleaseFeature()
{
    // make sure the commandQueue is done before anything it may still reference is released
    if (m_CmdQSyncObjNGX && m_commandQueueNGX)
    {
        m_CmdQSyncObjNGX->SignalFence(m_commandQueueNGX);
        m_CmdQSyncObjNGX->WaitForCPUFence();
    }
    if (m_bNGXInitialized)
    {
        NVSDK_NGX_D3D12_ReleaseFeature(m_TrueHDRFeature);
        m_TrueHDRFeature = nullptr;

        // Destroy the capability parameter map while the SDK is still initialized;
        // doing it after Shutdown1 would call into an SDK instance that no longer exists.
        NVSDK_NGX_D3D12_DestroyParameters(m_ngxParameters);
        m_ngxParameters = nullptr;

        NVSDK_NGX_D3D12_Shutdown1(m_D3D12Device);

        m_bNGXInitialized = false;
    }
    SafeDelete(m_CmdQSyncObjNGX);
    SafeRelease(m_pDstTmpNGX);
    SafeRelease(m_commandQueueNGX);
    SafeRelease(m_commandListNGX);
    SafeRelease(m_commandAllocatorNGX);
    SafeRelease(m_D3D12Device);
}


// example of a DX12 Sync Object
#if 0

// Wraps one ID3D12Fence plus a lazily created Win32 event.
// SignalFence() enqueues a signal with the next fence value on a queue; the Wait* methods make
// either a queue (GPU side) or the calling thread (CPU side) wait for the last signaled value.
class CDx12SyncObject
{
public:
    CDx12SyncObject(ID3D12Device* pD3D12Device);
    ~CDx12SyncObject();

    // The object owns a raw fence reference and an event handle that the destructor releases;
    // a copy would release both twice, so copying is disabled.
    CDx12SyncObject(const CDx12SyncObject&) = delete;
    CDx12SyncObject& operator=(const CDx12SyncObject&) = delete;

    UINT64 SignalFence(ID3D12CommandQueue* pCmdQ);
    void WaitForFence(ID3D12CommandQueue* pCmdQ);
    void WaitForCPUFence();

    // these are for checking if an older value has finished
    void WaitForFenceValue(ID3D12CommandQueue* pCmdQ, UINT64 waitValue);
    void WaitForCPUFenceValue(UINT64 waitValue);

private:
    // Synchronization objects.
    ID3D12Fence* m_fence = nullptr;
    UINT64                          m_fenceValue            = 0;        // next value to signal
    UINT64                          m_fenceSignaledValue    = 0;        // last value passed to Signal
    HANDLE                          m_fenceEvent            = nullptr;  // created on first CPU wait
};


//////////////////////////////////////////////////////////////////////////////////////
// Creates the fence with the current counter value as its initial value and advances the
// counter, so the first SignalFence() uses a value above the initial one.
CDx12SyncObject::CDx12SyncObject(ID3D12Device* pD3D12Device)
{
    const UINT64 initialFenceValue = m_fenceValue;
    m_fenceValue = initialFenceValue + 1;
    pD3D12Device->CreateFence(initialFenceValue, D3D12_FENCE_FLAG_NONE, IID_PPV_ARGS(&m_fence));
}

//////////////////////////////////////////////////////////////////////////////////////
// Releases the fence and closes the event handle via the project's Safe* helpers.
CDx12SyncObject::~CDx12SyncObject()
{
    // drop the fence reference taken in the constructor
    SafeRelease(m_fence);
    // the event only exists if WaitForCPUFence() had to create it
    SafeCloseHandle(m_fenceEvent);
}

//////////////////////////////////////////////////////////////////////////////////////
// Enqueues a fence signal on pCmdQ using the next counter value, remembers that value as the
// one the Wait* methods wait for, and returns it.
UINT64 CDx12SyncObject::SignalFence(ID3D12CommandQueue* pCmdQ)
{
    // take the next value, then advance the counter for the following signal
    const UINT64 valueToSignal = m_fenceValue;
    ++m_fenceValue;

    m_fenceSignaledValue = valueToSignal;
    pCmdQ->Signal(m_fence, valueToSignal);
    return valueToSignal;
}


//////////////////////////////////////////////////////////////////////////////////////
// Makes pCmdQ wait (GPU side) for the last signaled value, unless the fence already reached it.
void CDx12SyncObject::WaitForFence(ID3D12CommandQueue* pCmdQ)
{
    const bool alreadyReached = m_fence->GetCompletedValue() >= m_fenceSignaledValue;
    if (alreadyReached)
    {
        return;
    }

    pCmdQ->Wait(m_fence, m_fenceSignaledValue);
}


//////////////////////////////////////////////////////////////////////////////////////
void CDx12SyncObject::WaitForCPUFence()
{
    if (m_fence->GetCompletedValue() < m_fenceSignaledValue)
    {
        if (!m_fenceEvent)
            m_fenceEvent = CreateEvent(nullptr, FALSE, FALSE, nullptr);
        if (m_fenceEvent)
        {
            m_fence->SetEventOnCompletion(m_fenceSignaledValue, m_fenceEvent);
            WaitForSingleObject(m_fenceEvent, INFINITE);
        }
    }
}

#endif
