CharRNN sample with TensorRT

Optimize a Multi-layer Recurrent Neural Network with TensorRT

In this notebook we show the step-by-step procedure for using TensorRT to optimize a trained character-level language model implemented as a multi-layer Recurrent Neural Network.

STEP-1: Exporting Weights From a TensorFlow Model Checkpoint

  1. We have trained the model in TensorFlow and provide the model checkpoint files created during training. The model used by this sample was trained with the GitHub repository: https://github.com/crazydonkey200/tensorflow-char-rnn
  2. A Python script, /usr/src/tensorrt/samples/common/dumpTFWts.py, is provided to extract the weights from the model checkpoint files. Run dumpTFWts.py -h for directions on its usage.
In [ ]:
# In our case, OUTPUT is char-rnn.wts
python dumpTFWts.py -m MODEL -o OUTPUT
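
For example, if training produced checkpoint files with the prefix model.ckpt, the weights could be dumped as follows (the checkpoint path is hypothetical; adjust it to your training output):
In [ ]:
# the checkpoint prefix here is illustrative only
python /usr/src/tensorrt/samples/common/dumpTFWts.py -m model.ckpt -o char-rnn.wts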

STEP-2: Load Weights and Create an Empty TensorRT Network

In [ ]:
// Load weights from a formatted file into a map.
std::map<std::string, Weights> weightMap = loadWeights(locateFile("char-rnn.wts"), weightNames);

// create the builder and an empty network
IBuilder* builder = createInferBuilder(gLogger);
INetworkDefinition* network = builder->createNetwork();
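
The loadWeights helper is not shown above; here is a minimal sketch of it, assuming the plain-text .wts format used by the TensorRT samples (a tensor count on the first line, then one line per tensor with its name, element type, element count, and hex-encoded values):

In [ ]:
// Minimal sketch of loadWeights; the file-format assumptions are stated above.
// (requires <fstream>, <map>, <vector>, <algorithm>, <string>)
std::map<std::string, Weights> loadWeights(const std::string& file, const std::vector<std::string>& names)
{
    std::map<std::string, Weights> weightMap;
    std::ifstream input(file);
    int32_t count;
    input >> count; // number of weight tensors in the file
    while (count--)
    {
        Weights wt{DataType::kFLOAT, nullptr, 0};
        std::string name;
        uint32_t type, size;
        input >> name >> std::dec >> type >> size;
        wt.type = static_cast<DataType>(type);

        // values are stored as space-separated hex words
        uint32_t* val = reinterpret_cast<uint32_t*>(malloc(sizeof(uint32_t) * size));
        for (uint32_t x = 0; x < size; ++x)
            input >> std::hex >> val[x];
        wt.values = val;
        wt.count = size;

        // keep only the tensors this sample actually asks for
        if (std::find(names.begin(), names.end(), name) != names.end())
            weightMap[name] = wt;
    }
    return weightMap;
}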

STEP-3: Add the required layers to the Network

  • RNNv2 layer: Create the RNN layer by specifying the operation of the RNN cell as LSTM. The layer also supports ReLU, GRU, and tanh cell operations.

    In [ ]:
    // LAYER_COUNT -> Number of stacked RNN layers. In our case it's 2.
    // HIDDEN_SIZE -> Number of hidden states. In our case it's 512.
    // SEQ_SIZE    -> Sequence length. In our case it's 1.
    auto rnn = network->addRNNv2(*data, LAYER_COUNT, HIDDEN_SIZE, SEQ_SIZE, RNNOperation::kLSTM);
    
    // convert tensorflow weight format to trt weight format
    Weights rnnwL0 = convertRNNWeights(weightMap[RNNW_L0_NAME]);
    Weights rnnbL0 = convertRNNBias(weightMap[RNNB_L0_NAME]);
    Weights rnnwL1 = convertRNNWeights(weightMap[RNNW_L1_NAME]);
    Weights rnnbL1 = convertRNNBias(weightMap[RNNB_L1_NAME]);
    
    // type casting 
    const nvinfer1::DataType dataType = static_cast<nvinfer1::DataType>(rnnwL0.type);
    const float * wtsL0 = static_cast<const float *>(rnnwL0.values);
    const float * biasesL0 = static_cast<const float *>(rnnbL0.values);
    const float * wtsL1 = static_cast<const float *>(rnnwL1.values);
    const float * biasesL1 = static_cast<const float *>(rnnbL1.values);
    
    // Gate order used by TensorFlow's LSTM cell; the running offsets track where each
    // gate's weights and biases start inside the converted buffers.
    std::vector<nvinfer1::RNNGateType> gateOrder({nvinfer1::RNNGateType::kINPUT,
                                                  nvinfer1::RNNGateType::kCELL,
                                                  nvinfer1::RNNGateType::kFORGET,
                                                  nvinfer1::RNNGateType::kOUTPUT});
    size_t kernelOffset = 0, biasOffset = 0;

    // Iterate over the four gates twice: the first pass (gateIndex < numGates) sets the
    // input weights and biases, the second pass sets the recurrent ones, for both stacked layers.
    for (size_t gateIndex = 0, numGates = gateOrder.size(); gateIndex < 2 * numGates; gateIndex++)
    {
        // extract weights and bias for a given gate and layer
        Weights gateWeightL0{.type = dataType, .values = (void*)(wtsL0 + kernelOffset), .count = DATA_SIZE * HIDDEN_SIZE};
        Weights gateBiasL0{.type = dataType, .values = (void*)(biasesL0 + biasOffset), .count = HIDDEN_SIZE};
        Weights gateWeightL1{.type = dataType, .values = (void*)(wtsL1 + kernelOffset), .count = DATA_SIZE * HIDDEN_SIZE};
        Weights gateBiasL1{.type = dataType, .values = (void*)(biasesL1 + biasOffset), .count = HIDDEN_SIZE};
    
        // set weights and bias for given gate
        rnn->setWeightsForGate(0, gateOrder[gateIndex % numGates], (gateIndex < numGates), gateWeightL0);
        rnn->setBiasForGate(0, gateOrder[gateIndex % numGates], (gateIndex < numGates), gateBiasL0);
        rnn->setWeightsForGate(1, gateOrder[gateIndex % numGates], (gateIndex < numGates), gateWeightL1);
        rnn->setBiasForGate(1, gateOrder[gateIndex % numGates], (gateIndex < numGates), gateBiasL1);
        
        // Update offsets
        kernelOffset = kernelOffset + DATA_SIZE * HIDDEN_SIZE;
        biasOffset = biasOffset + HIDDEN_SIZE;
    
    }
    
    In [ ]:
    // Helper function to convert RNN weights from TensorFlow's format to TensorRT's format.
    // TensorFlow stores the LSTM kernel as one [2*HIDDEN_SIZE, 4*HIDDEN_SIZE] matrix with the
    // input and recurrent weights stacked and the four gates concatenated; TensorRT expects
    // transposed, per-gate sub-buffers, which reshapeWeights/transposeSubBuffers produce.
    Weights convertRNNWeights(Weights input)
    {
        float* ptr = static_cast<float*>(malloc(sizeof(float)*input.count));
        int dims[4]{2, HIDDEN_SIZE, 4, HIDDEN_SIZE};
        int order[4]{ 0, 3, 1, 2};
        utils::reshapeWeights(input, dims, order, ptr, 4);
        utils::transposeSubBuffers(ptr, DataType::kFLOAT, 2, HIDDEN_SIZE * HIDDEN_SIZE, 4);
        return Weights{input.type, ptr, input.count};
    }
    
    // Helper function to convert RNN biases from TensorFlow's format to TensorRT's format.
    // TensorFlow keeps a single combined bias per gate, while TensorRT expects separate
    // input and recurrent biases, so the second half of the output buffer is zero-filled.
    Weights convertRNNBias(Weights input)
    {
        const int sizeOfElement = samples_common::getElementSize(input.type);
        char* ptr = static_cast<char*>(malloc(sizeOfElement*input.count*2));
        const char* iptr = static_cast<const char*>(input.values);
        std::copy(iptr, iptr + 4 * HIDDEN_SIZE * sizeOfElement, ptr);
        std::fill(ptr + sizeOfElement * input.count, ptr + sizeOfElement * input.count * 2, 0);
        return Weights{input.type, ptr, input.count*2};
    }
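
    Besides the gate weights, the RNNv2 layer also consumes and produces the hidden and cell states, so state can be carried across the single-character steps used later during inference. A minimal sketch, assuming the state tensors are added as network inputs (the blob names here are assumptions):

    In [ ]:
    // The blob names are assumptions; they must match the names used when binding buffers.
    auto hiddenIn = network->addInput(HIDDEN_IN_BLOB_NAME, DataType::kFLOAT, Dims2(LAYER_COUNT, HIDDEN_SIZE));
    rnn->setHiddenState(*hiddenIn);

    auto cellIn = network->addInput(CELL_IN_BLOB_NAME, DataType::kFLOAT, Dims2(LAYER_COUNT, HIDDEN_SIZE));
    rnn->setCellState(*cellIn);

    // Expose the final states as outputs so each step's Ht/Ct can be copied back to the host.
    rnn->getOutput(1)->setName(HIDDEN_OUT_BLOB_NAME);
    network->markOutput(*rnn->getOutput(1));
    rnn->getOutput(2)->setName(CELL_OUT_BLOB_NAME);
    network->markOutput(*rnn->getOutput(2));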
    
  • Fully Connected Layer = MatrixMultiply Layer + ElementWise Layer
  • MatrixMultiply layer setup: The MatrixMultiply layer executes the first step of the functionality provided by a FullyConnected layer. As shown in the code below, a Constant layer is needed so that the FullyConnected weights can be stored in the engine. The outputs of the Constant and RNN layers are then used as inputs to the MatrixMultiply layer.

    In [ ]:
    // add Constant layers for fully connected weights
    auto fcwts = network->addConstant(Dims2(VOCAB_SIZE, HIDDEN_SIZE), weightMap[FCW_NAME]);
    
    // Add matrix multiplication layer for multiplying rnn output with FC weights
    auto matrixMultLayer = network->addMatrixMultiply(*fcwts->getOutput(0), false, *rnn->getOutput(0), true);
    

    ElementWise layer setup: The ElementWise layer executes the second step of the functionality provided by a FullyConnected layer. The outputs of the fcbias Constant layer and the MatrixMultiply layer are used as inputs to the ElementWise layer. The output from this layer is then supplied to the TopK layer.

    In [ ]:
    // Add Constant layers for fully connected biases
    auto fcbias = network->addConstant(Dims2(VOCAB_SIZE, 1), weightMap[FCB_NAME]);
    
    // Add elementwise layer for adding bias
    auto addBiasLayer = network->addElementWise(*matrixMultLayer->getOutput(0), *fcbias->getOutput(0), ElementWiseOperation::kSUM);
    
  • TopK layer: The TopK layer is used to identify the character that has the maximum probability of appearing next.

    Note: The layer has two outputs. The first output is an array of the top K values. The second, which is of more interest to us, is the index at which these maximum values appear. The code below sets up the TopK layer and assigns the OUTPUT_BLOB_NAME to the second output of the layer.

    In [ ]:
    // reduceAxis is a bitmask of axes to reduce over; bit 0 selects the vocabulary
    // dimension of the fully connected output.
    uint32_t reduceAxis = 0x1;

    // Add TopK layer to determine which character has the highest probability.
    auto pred = network->addTopK(*addBiasLayer->getOutput(0), nvinfer1::TopKOperation::kMAX, 1, reduceAxis);
    pred->getOutput(1)->setName(OUTPUT_BLOB_NAME);
    

    STEP-4: Build the TensorRT Inference Engine

    In [ ]:
    // Build the engine
    auto engine = builder->buildCudaEngine(*network);

    // Serialize the engine so it can be deserialized later for inference
    (*modelStream) = engine->serialize();

    // Clean up build-time resources; the serialized stream holds everything we need
    engine->destroy();
    network->destroy();
    builder->destroy();
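
    Before buildCudaEngine is called, the builder is normally given a maximum batch size and a scratch-memory budget; the exact values below are illustrative:

    In [ ]:
    // The sample steps one character at a time, so a batch size of 1 suffices.
    builder->setMaxBatchSize(1);
    // Upper bound on temporary GPU memory TensorRT may use (32 MB here).
    builder->setMaxWorkspaceSize(1 << 25);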
    

    STEP-5: Retrieving the Engine from the Serialized Model Stream

    In [ ]:
    // Initialize engine, context, and other runtime resources
    IRuntime* runtime = createInferRuntime(gLogger);
        
    ICudaEngine* engine = runtime->deserializeCudaEngine(modelStream->data(), modelStream->size(), nullptr);
    
    IExecutionContext *context = engine->createExecutionContext();
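
    When inference is finished, these objects are released in reverse order of creation; TensorRT versions contemporary with this sample use explicit destroy() calls:

    In [ ]:
    // Release runtime resources once inference is complete.
    context->destroy();
    engine->destroy();
    runtime->destroy();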
    

    STEP-6: Running Inference

    Given an input string, this function first seeds the model with each input character in turn and then generates the remainder of the expected string, one character per inference step. The TensorRT context is used to run the model; weightMap contains all the weights required by the model, including the embedding table used to encode each character.

    In [ ]:
    // Helper function that takes in context, input string, weightMap as input and then generates the expected string.
    bool doInference(IExecutionContext& context, std::string input, std::string expected, std::map<std::string, Weights> &weightMap)
    {
        const ICudaEngine& engine = context.getEngine();
    
        // allocate memory on host and device
        allocateMemory(engine, buffers, data, indices);
    
        // create stream for trt execution
        cudaStream_t stream;
        CHECK(cudaStreamCreate(&stream));
    
        auto embed = weightMap[EMBED_NAME];
        std::string genstr;
    
        // Seed the RNN with the input.
        for (auto &a : input)
        {
            std::copy(static_cast<const float*>(embed.values) + char_to_id[a]*DATA_SIZE,
                    static_cast<const float*>(embed.values) + char_to_id[a]*DATA_SIZE + DATA_SIZE,
                    data[INPUT_IDX]);
    
            stepOnce(data, buffers, indices, stream, context);
            cudaStreamSynchronize(stream);
    
        // Copy Ct/Ht to the Ct-1/Ht-1 slots. Ct and Ht are the cell-state and hidden-state outputs of the LSTM unit.
            std::memcpy(data[HIDDEN_IN_IDX], data[HIDDEN_OUT_IDX], gSizes[HIDDEN_IN_IDX] * sizeof(float));
            std::memcpy(data[CELL_IN_IDX], data[CELL_OUT_IDX], gSizes[CELL_IN_IDX] * sizeof(float));
    
            genstr.push_back(a);
        }
        // Extract first predicted character
        uint32_t predIdx = *reinterpret_cast<uint32_t *>(data[OUTPUT_IDX]);
        genstr.push_back(id_to_char[predIdx]);
    
        // Generate predicted sequence of characters.  
        // The following code simply selects the character with the highest probability. The final result is stored in genstr.
        for (size_t x = 0, y = expected.size() - 1; x < y; x++)
        {
            std::copy(static_cast<const float*>(embed.values) + char_to_id[*genstr.rbegin()]*DATA_SIZE,
                     static_cast<const float*>(embed.values) + char_to_id[*genstr.rbegin()]*DATA_SIZE + DATA_SIZE,
                     data[INPUT_IDX]);
    
            stepOnce(data, buffers, indices, stream, context);
            cudaStreamSynchronize(stream);
    
            // Copy Ct/Ht to the Ct-1/Ht-1 slots.
            std::memcpy(data[HIDDEN_IN_IDX], data[HIDDEN_OUT_IDX], gSizes[HIDDEN_IN_IDX] * sizeof(float));
            std::memcpy(data[CELL_IN_IDX], data[CELL_OUT_IDX], gSizes[CELL_IN_IDX] * sizeof(float));
    
            predIdx = *reinterpret_cast<uint32_t *>(data[OUTPUT_IDX]);
            genstr.push_back(id_to_char[predIdx]);
        }
    
        return genstr == (input + expected);
    }
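
    doInference relies on the char_to_id and id_to_char tables to translate between characters and rows of the embedding matrix. A minimal sketch of how they could be built, assuming a vocabulary string that matches the one used during training (the contents shown are purely illustrative):

    In [ ]:
    // Hypothetical vocabulary; it must match the character set the model was trained on.
    static const std::string vocab = "\nabcdefghijklmnopqrstuvwxyz .,;:'";
    std::map<char, int> char_to_id;
    std::map<int, char> id_to_char;
    for (int i = 0; i < static_cast<int>(vocab.size()); ++i)
    {
        char_to_id[vocab[i]] = i;
        id_to_char[i] = vocab[i];
    }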
    
    In [ ]:
    // data    -> The CPU buffers used to copy back and forth data from the engine.
    // buffers -> The engine buffers that will be used for input and output.
    // indices -> The indices that the engine has bound specific blobs to.
    // stream  -> The cuda stream used during execution.
    // context -> The TensorRT context used to run the model.
    
    const int INPUT_IDX = 0;
    const int HIDDEN_IN_IDX = 1;
    const int CELL_IN_IDX = 2;
    const int HIDDEN_OUT_IDX = 3;
    const int CELL_OUT_IDX = 4;
    const int OUTPUT_IDX = 5; // index of the TopK prediction output
    
    void stepOnce(float **data, void **buffers, int *indices, cudaStream_t &stream, IExecutionContext &context)
    {
        // DMA the input to the GPU
        CHECK(cudaMemcpyAsync(buffers[indices[INPUT_IDX]], data[INPUT_IDX], gSizes[INPUT_IDX] * sizeof(float), cudaMemcpyHostToDevice, stream));
        CHECK(cudaMemcpyAsync(buffers[indices[HIDDEN_IN_IDX]], data[HIDDEN_IN_IDX], gSizes[HIDDEN_IN_IDX] * sizeof(float), cudaMemcpyHostToDevice, stream));
        CHECK(cudaMemcpyAsync(buffers[indices[CELL_IN_IDX]], data[CELL_IN_IDX], gSizes[CELL_IN_IDX] * sizeof(float), cudaMemcpyHostToDevice, stream));
    
        // Execute asynchronously
        context.enqueue(1, buffers, stream, nullptr);
    
        // DMA the output from the GPU
        CHECK(cudaMemcpyAsync(data[HIDDEN_OUT_IDX], buffers[indices[HIDDEN_OUT_IDX]], gSizes[HIDDEN_OUT_IDX] * sizeof(float), cudaMemcpyDeviceToHost, stream));
        CHECK(cudaMemcpyAsync(data[CELL_OUT_IDX], buffers[indices[CELL_OUT_IDX]], gSizes[CELL_OUT_IDX] * sizeof(float), cudaMemcpyDeviceToHost, stream));
        CHECK(cudaMemcpyAsync(data[OUTPUT_IDX], buffers[indices[OUTPUT_IDX]], gSizes[OUTPUT_IDX] * sizeof(float), cudaMemcpyDeviceToHost, stream));
    }
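
    The indices array maps the fixed slot constants above to the engine's binding order. A minimal sketch of how it could be filled, assuming the blob names given to the tensors when the network was built (the names are assumptions):

    In [ ]:
    // getBindingIndex returns the engine's binding slot for a named tensor.
    int indices[6];
    indices[INPUT_IDX]      = engine.getBindingIndex(INPUT_BLOB_NAME);
    indices[HIDDEN_IN_IDX]  = engine.getBindingIndex(HIDDEN_IN_BLOB_NAME);
    indices[CELL_IN_IDX]    = engine.getBindingIndex(CELL_IN_BLOB_NAME);
    indices[HIDDEN_OUT_IDX] = engine.getBindingIndex(HIDDEN_OUT_BLOB_NAME);
    indices[CELL_OUT_IDX]   = engine.getBindingIndex(CELL_OUT_BLOB_NAME);
    indices[OUTPUT_IDX]     = engine.getBindingIndex(OUTPUT_BLOB_NAME);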
    