/**
 * Copyright 2022-2025 NVIDIA Corporation.  All rights reserved.
 *
 * Please refer to the NVIDIA end user license agreement (EULA) associated
 * with this source code for terms and conditions that govern your use of
 * this software. Any use, reproduction, disclosure, or distribution of
 * this software and related documentation outside the terms of the EULA
 * is strictly prohibited.
 *
 */

////////////////////////////////////////////////////////////////////////////////

#ifndef HELPER_CUPTI_ACTIVITY_H_
#define HELPER_CUPTI_ACTIVITY_H_

#pragma once

// CUPTI headers
#include <cupti.h>
#include <helper_cupti_activity_structs.h>

// CUPTI buffer size 8 MB
// Default size, in bytes, of each activity buffer handed to CUPTI. InitCuptiTrace()
// uses this value when UserData::activityBufferSize is 0.
#define BUF_SIZE (8 * 1024 * 1024)

// 8-byte alignment for the buffers
#define ALIGN_SIZE (8)
// ALIGN_BUFFER(buffer, align) yields `buffer` rounded up to the next multiple of `align`.
// - The mask arithmetic (`& (align - 1)`) is only correct when `align` is a power of two.
// - Both arguments are expanded more than once, so they must be free of side effects.
// - When `buffer` is not already aligned the result differs from `buffer`; such a
//   result is not the address returned by the allocator and must not be passed to free().
#define ALIGN_BUFFER(buffer, align)                                                 \
  (((uintptr_t) (buffer) & ((align)-1)) ? ((buffer) + (align) - ((uintptr_t) (buffer) & ((align)-1))) : (buffer))

// Key type for hash maps. NOTE(review): not referenced in the visible part of this
// header; presumably used by the samples that include it - confirm.
typedef uint64_t HashMapKey;

// Alias for the CUPTI activity-kind count constant.
#define MAX_ACTIVITY_KINDS CUPTI_ACTIVITY_KIND_COUNT

// Data structures

// Global state
// Bookkeeping shared by the helper routines in this header.
typedef struct GlobalState_st
{
    // CUPTI subscriber handle used to subscribe to CUPTI callbacks.
    CUpti_SubscriberHandle subscriberHandle;

    // Size in bytes of each activity buffer handed to CUPTI by BufferRequested().
    size_t activityBufferSize;

    // File handle the CUPTI activity records are printed to.
    // BufferCompleted() falls back to stdout when this is NULL.
    FILE *pOutputFile;

    // User data used to initialize the CUPTI trace. Refer to the UserData structure.
    void *pUserData;

    // Number of buffers requested by CUPTI (incremented in BufferRequested()).
    uint64_t buffersRequested;

    // Number of completed buffers received from CUPTI (incremented in BufferCompleted()).
    uint64_t buffersCompleted;
} GlobalState;

// Options supplied by the application through InitCuptiTrace().
// The sample owns the allocation of this structure; DeInitCuptiTrace() releases it
// with free(), so it has to come from malloc()/calloc().
// Set the options according to the requirements of the workload.
typedef struct UserData_st
{
    // CUPTI activity buffer size in bytes. 0 selects the default, BUF_SIZE.
    size_t activityBufferSize;

    // CUPTI device buffer size in bytes. 0 leaves the CUPTI attribute untouched.
    size_t deviceBufferSize;

    // Non-zero: flush CUPTI activity records at stream synchronization.
    uint8_t flushAtStreamSync;

    // Non-zero: flush CUPTI activity records at context synchronization.
    uint8_t flushAtCtxSync;

    // Non-zero: print every callback received from CUPTI.
    uint8_t printCallbacks;

    // Non-zero: print the CUPTI activity records.
    uint8_t printActivityRecords;

    // Non-zero: the application does not want InitCuptiTrace() to subscribe to CUPTI.
    uint8_t skipCuptiSubscription;

    // Optional hook, called once per activity record, for post-processing in the
    // user application. May be NULL.
    void (*pPostProcessActivityRecords)(CUpti_Activity *pRecord);
} UserData;

// Global variables
// Zero-initialized state used by the helper functions below. Because the variable is
// declared `static` in a header, every translation unit that includes this header
// gets its own independent copy.
static GlobalState globals = { 0 };

// Returns a printable name for the callback `cbid` that belongs to `domain`.
//
// - DRIVER_API / RUNTIME_API: name reported by cuptiGetCallbackName().
// - RESOURCE / SYNCHRONIZE / STATE: name produced by the matching
//   GetCallbackId*String() helper.
// - NVTX and any unrecognized domain: "<unknown>".
//
// The function never returns NULL.
static const char *
GetCallbackString(CUpti_CallbackDomain domain,
                  uint32_t cbid)
{
    switch (domain)
    {
        case CUPTI_CB_DOMAIN_INVALID:
            return "INVALID";
        case CUPTI_CB_DOMAIN_DRIVER_API:
        case CUPTI_CB_DOMAIN_RUNTIME_API:
        {
            const char *name = NULL;
            // API cuptiGetCallbackName is available for DRIVER and RUNTIME domains only
            CUPTI_API_CALL(cuptiGetCallbackName(domain, cbid, &name));
            // Keep the "never NULL" contract even if no name was filled in.
            return name ? name : "<unknown>";
        }
        case CUPTI_CB_DOMAIN_RESOURCE:
        {
            return GetCallbackIdResourceString((CUpti_CallbackIdResource)cbid);
        }
        case CUPTI_CB_DOMAIN_SYNCHRONIZE:
        {
            return GetCallbackIdSyncString((CUpti_CallbackIdSync)cbid);
        }
        case CUPTI_CB_DOMAIN_NVTX:
        {
            // No NVTX name lookup is wired up yet. Return explicitly instead of
            // falling through to the STATE case, which would decode the NVTX cbid
            // as a CUpti_CallbackIdState and report a misleading name.
            // return Getnvtx_api_trace_cbid((CUpti_nvtx_api_trace_cbid)cbid);
            return "<unknown>";
        }
        case CUPTI_CB_DOMAIN_STATE:
        {
            return GetCallbackIdStateString((CUpti_CallbackIdState)cbid);
        }
        default:
            return "<unknown>";
    }
}

static void
PrintActivityBuffer(
    uint8_t *pBuffer,
    size_t validBytes,
    FILE *pFileHandle,
    void *pUserData)
{
    CUpti_Activity *pRecord = NULL;
    CUptiResult status = CUPTI_SUCCESS;

    do
    {
        status = cuptiActivityGetNextRecord(pBuffer, validBytes, &pRecord);
        if (status == CUPTI_SUCCESS)
        {
            if (!pUserData ||
                (pUserData && ((UserData *)pUserData)->printActivityRecords))
            {
                PrintActivity(pRecord, pFileHandle);
            }

            if (pUserData &&
                ((UserData *)pUserData)->pPostProcessActivityRecords)
            {
                ((UserData *)pUserData)->pPostProcessActivityRecords(pRecord);
            }
        }
        else if (status == CUPTI_ERROR_MAX_LIMIT_REACHED)
        {
            break;
        }
        else if (status == CUPTI_ERROR_INVALID_KIND)
        {
            break;
        }
        else
        {
            CUPTI_API_CALL(status);
        }
    } while (1);
}

// Buffer Management Functions
// CUPTI "buffer requested" callback: allocates one activity buffer for CUPTI.
//
//   ppBuffer       - receives the address of the new buffer.
//   pSize          - receives the buffer size in bytes (globals.activityBufferSize,
//                    or BUF_SIZE while that is still 0).
//   pMaxNumRecords - set to 0.
//
// The address written to *ppBuffer is exactly the address returned by malloc(),
// so whoever receives the buffer back can release it with free(). The previous
// implementation handed out ALIGN_BUFFER(pBuffer, ALIGN_SIZE); had that ever
// moved the pointer, free() would have been called on an address that malloc()
// never returned.
static void CUPTIAPI
BufferRequested(
    uint8_t **ppBuffer,
    size_t *pSize,
    size_t *pMaxNumRecords)
{
    // Do not request a zero-byte allocation if the size was never configured.
    const size_t bufferSize = globals.activityBufferSize ? globals.activityBufferSize : BUF_SIZE;

    uint8_t *pBuffer = (uint8_t *) malloc(bufferSize);
    MEMORY_ALLOCATION_CALL(pBuffer);

    // malloc() returns storage aligned for any fundamental type, which covers
    // ALIGN_SIZE. Verify instead of adjusting, so the pointer stays free()-able.
    if ((uintptr_t) pBuffer & (ALIGN_SIZE - 1))
    {
        fprintf(stderr, "Activity buffer is not aligned to %d bytes.\n", (int) ALIGN_SIZE);
        free(pBuffer);
        exit(EXIT_FAILURE);
    }

    *pSize = bufferSize;
    *ppBuffer = pBuffer;
    *pMaxNumRecords = 0;

    globals.buffersRequested++;
}

// CUPTI "buffer completed" callback: consumes a buffer that CUPTI has filled
// and then releases it.
//
//   context, streamId, size - not used by this helper.
//   pBuffer   - buffer being returned by CUPTI; released with free() before returning.
//   validSize - number of bytes of valid record data in pBuffer.
//
// Records are processed only when validSize is non-zero. They are written to
// globals.pOutputFile, or to stdout when no output file has been set.
static void CUPTIAPI
BufferCompleted(
    CUcontext context,
    uint32_t streamId,
    uint8_t *pBuffer,
    size_t size,
    size_t validSize)
{
    if (validSize > 0)
    {
        FILE *pSink = globals.pOutputFile ? globals.pOutputFile : stdout;
        PrintActivityBuffer(pBuffer, validSize, pSink, globals.pUserData);
    }

    ++globals.buffersCompleted;
    free(pBuffer);
}

// CUPTI callback functions
// Reacts to callbacks from the CUPTI synchronize domain.
//
//   callbackId       - synchronize-domain callback id.
//   pSynchronizeData - callback payload; its context/stream are read for
//                      stream-synchronized callbacks.
//   pUserData        - UserData* carrying the flushAtCtxSync / flushAtStreamSync options.
//
// Context synchronized + flushAtCtxSync    -> flush all activity records.
// Stream synchronized  + flushAtStreamSync -> query the stream id, then flush all
//                                             activity records.
// Every other combination is ignored.
static void
HandleSyncronizationCallbacks(
    CUpti_CallbackId callbackId,
    const CUpti_SynchronizeData *pSynchronizeData,
    void *pUserData)
{
    UserData *pOptions = (UserData *)pUserData;

    switch (callbackId)
    {
        case CUPTI_CBID_SYNCHRONIZE_CONTEXT_SYNCHRONIZED:
        {
            if (pOptions->flushAtCtxSync)
            {
                CUPTI_API_CALL_VERBOSE(cuptiActivityFlushAll(0));
            }
            break;
        }
        case CUPTI_CBID_SYNCHRONIZE_STREAM_SYNCHRONIZED:
        {
            if (pOptions->flushAtStreamSync)
            {
                // The id itself is not used afterwards, but the query is still checked.
                uint32_t streamId = 0;
                CUPTI_API_CALL_VERBOSE(cuptiGetStreamId(pSynchronizeData->context, pSynchronizeData->stream, &streamId));
                CUPTI_API_CALL_VERBOSE(cuptiActivityFlushAll(0));
            }
            break;
        }
        default:
            break;
    }
}

// Handles callbacks from the CUPTI state domain.
//
//   callbackId - state-domain callback id.
//   pStateData - callback payload; notification.result and notification.message
//                are read for fatal errors.
//
// On CUPTI_CBID_STATE_FATAL_ERROR the error is reported and the process exits
// with EXIT_FAILURE. All other ids are ignored.
//
// The report goes to globals.pOutputFile, or to stdout when no output file has
// been set (same fallback as BufferCompleted()), so a NULL file handle no longer
// reaches fprintf(). Missing strings are replaced by placeholders rather than
// being passed to "%s" as NULL.
static void
HandleDomainStateCallback(
    CUpti_CallbackId callbackId,
    const CUpti_StateData *pStateData)
{
    switch (callbackId)
    {
        case CUPTI_CBID_STATE_FATAL_ERROR:
        {
            FILE *pOutputFile = globals.pOutputFile ? globals.pOutputFile : stdout;

            const char *errorString = NULL;
            if (cuptiGetResultString(pStateData->notification.result, &errorString) != CUPTI_SUCCESS ||
                errorString == NULL)
            {
                errorString = "<unknown>";
            }

            const char *message = pStateData->notification.message;
            if (message == NULL)
            {
                message = "<none>";
            }

            fprintf(pOutputFile, "\nCUPTI encountered fatal error: %s\n", errorString);
            fprintf(pOutputFile, "Error message: %s\n", message);

            // Exiting the application if fatal error encountered in CUPTI
            // If there is a CUPTI fatal error, it means CUPTI has stopped profiling the application.
            exit(EXIT_FAILURE);
        }
        default:
            break;
    }
}

// Default CUPTI callback, installed by InitCuptiTrace() when the application
// does not provide its own.
//
//   pUserData     - UserData* that was registered at subscription time.
//   domain        - callback domain.
//   callbackId    - callback id within the domain.
//   pCallbackData - domain-specific payload.
//
// Behavior:
//   - Optionally logs the domain / callback id (printCallbacks set and an output
//     file present).
//   - STATE domain       -> HandleDomainStateCallback().
//   - RUNTIME_API domain -> flushes activity records on entry to cudaDeviceReset.
//   - SYNCHRONIZE domain -> HandleSyncronizationCallbacks().
//   - Anything else is ignored.
static void CUPTIAPI
CuptiCallbackHandler(
    void *pUserData,
    CUpti_CallbackDomain domain,
    CUpti_CallbackId callbackId,
    const void *pCallbackData)
{
    CUPTI_API_CALL(cuptiGetLastError());

    UserData *pOptions = (UserData *)pUserData;
    if (pOptions->printCallbacks && globals.pOutputFile != NULL)
    {
        fprintf(globals.pOutputFile, "CUPTI Callback: Domain %d CbId %d\n", domain, callbackId);
        fflush(globals.pOutputFile);
    }

    if (domain == CUPTI_CB_DOMAIN_STATE)
    {
        HandleDomainStateCallback(callbackId, (const CUpti_StateData *)pCallbackData);
        return;
    }

    if (domain == CUPTI_CB_DOMAIN_SYNCHRONIZE)
    {
        HandleSyncronizationCallbacks(callbackId, (const CUpti_SynchronizeData *)pCallbackData, pUserData);
        return;
    }

    if (domain == CUPTI_CB_DOMAIN_RUNTIME_API &&
        callbackId == CUPTI_RUNTIME_TRACE_CBID_cudaDeviceReset_v3020)
    {
        const CUpti_CallbackData *pApiInfo = (const CUpti_CallbackData *)pCallbackData;
        if (pApiInfo->callbackSite == CUPTI_API_ENTER)
        {
            CUPTI_API_CALL_VERBOSE(cuptiActivityFlushAll(0));
        }
    }
}

// CUPTI Trace Setup
static void
InitCuptiTrace(
    void *pUserData,
    void *pTraceCallback,
    FILE *pFileHandle)
{
    if (!pUserData)
    {
        std::cerr << "Invalid parameter pUserData.\n";
        exit(EXIT_FAILURE);
    }

    globals.pOutputFile  = pFileHandle;
    globals.pUserData    = pUserData;

    // Subscribe to CUPTI
    if (((UserData *)pUserData)->skipCuptiSubscription == 0)
    {
        // If the user provides function pointer, subscribe CUPTI to that function pointer (pTraceCallback).
        // Else subscribe CUPTI to the common CuptiCallbackHandler.
        if (pTraceCallback)
        {
            CUPTI_API_CALL_VERBOSE(cuptiSubscribe(&globals.subscriberHandle, (CUpti_CallbackFunc)pTraceCallback, pUserData));
        }
        else
        {
            CUPTI_API_CALL_VERBOSE(cuptiSubscribe(&globals.subscriberHandle, (CUpti_CallbackFunc)CuptiCallbackHandler, pUserData));
        }


        // Enable CUPTI callback on context syncronization
        if (((UserData *)pUserData)->flushAtCtxSync)
        {
            CUPTI_API_CALL_VERBOSE(cuptiEnableCallback(1, globals.subscriberHandle, CUPTI_CB_DOMAIN_SYNCHRONIZE, CUPTI_CBID_SYNCHRONIZE_CONTEXT_SYNCHRONIZED));
        }

        // Enable CUPTI callback on stream syncronization
        if (((UserData *)pUserData)->flushAtStreamSync)
        {
            CUPTI_API_CALL_VERBOSE(cuptiEnableCallback(1, globals.subscriberHandle, CUPTI_CB_DOMAIN_SYNCHRONIZE, CUPTI_CBID_SYNCHRONIZE_STREAM_SYNCHRONIZED));
        }

        // Enable CUPTI callback on CUDA device reset by default
        CUPTI_API_CALL_VERBOSE(cuptiEnableCallback(1, globals.subscriberHandle, CUPTI_CB_DOMAIN_RUNTIME_API, CUPTI_RUNTIME_TRACE_CBID_cudaDeviceReset_v3020));

        // Enable CUPTI callback on fatal errors by default
        CUPTI_API_CALL_VERBOSE(cuptiEnableCallback(1, globals.subscriberHandle, CUPTI_CB_DOMAIN_STATE, CUPTI_CBID_STATE_FATAL_ERROR));
    }

    // Register callbacks for buffer requests and for buffers completed by CUPTI.
    globals.buffersRequested = 0;
    globals.buffersCompleted = 0;
    CUPTI_API_CALL_VERBOSE(cuptiActivityRegisterCallbacks(BufferRequested, BufferCompleted));

    // Optionally get and set activity attributes.
    // Attributes can be set by the CUPTI client to change behavior of the activity API.
    // Some attributes require to be set before any CUDA context is created to be effective,
    // E.g. To be applied to all device buffer allocations (see documentation).
    if ((((UserData *)pUserData))->deviceBufferSize != 0)
    {
        size_t attrValue = (((UserData *)pUserData))->deviceBufferSize;
        size_t attrValueSize = sizeof(size_t);
        CUPTI_API_CALL_VERBOSE(cuptiActivitySetAttribute(CUPTI_ACTIVITY_ATTR_DEVICE_BUFFER_SIZE, &attrValueSize, &attrValue));
        std::cout << "CUPTI_ACTIVITY_ATTR_DEVICE_BUFFER_SIZE = " << attrValue << " bytes.\n";
    }

    if ((((UserData *)pUserData))->activityBufferSize != 0)
    {
        globals.activityBufferSize = (((UserData *)pUserData))->activityBufferSize;
    }
    else
    {
        globals.activityBufferSize = BUF_SIZE;
    }

    std::cout << "Activity buffer size = " << globals.activityBufferSize << " bytes.\n";
}

// Tears down the tracing state created by InitCuptiTrace().
//
// Steps, in order:
//   1. Unsubscribe from CUPTI when a subscriber handle is held.
//   2. Flush all activity records (cuptiActivityFlushAll(1)).
//   3. Release the UserData block with free() - it therefore has to have been
//      allocated with malloc()/calloc() by the application.
//
// The handle and the user-data pointer are cleared once released. Without that,
// a second call would unsubscribe a stale handle and free the same block again,
// and a BufferCompleted() callback arriving later would read freed memory
// through globals.pUserData.
static void
DeInitCuptiTrace(void)
{
    CUPTI_API_CALL(cuptiGetLastError());

    if (globals.subscriberHandle)
    {
        CUPTI_API_CALL_VERBOSE(cuptiUnsubscribe(globals.subscriberHandle));
        globals.subscriberHandle = NULL;
    }

    CUPTI_API_CALL_VERBOSE(cuptiActivityFlushAll(1));

    // free(NULL) is a no-op, so no separate NULL check is needed.
    free(globals.pUserData);
    globals.pUserData = NULL;
}

#endif // HELPER_CUPTI_ACTIVITY_H_
