You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

CUDA Runtime与Driver API混合使用:初始化Runtime后上下文获取失败

CUDA Runtime初始化后,Driver API获取上下文报错CUDA_ERROR_INVALID_CONTEXT

我最近在尝试混合使用CUDA Runtime API和Driver API时遇到了一个问题:按照文档说明,Runtime初始化后应该可以通过cuCtxGetCurrent()获取其创建的主上下文,供Driver API调用,但实际运行时始终报错CUDA_ERROR_INVALID_CONTEXT

我的测试代码如下:

#define CUDA_DRIVER_API
#include <cuda.h>
#include <cuda_runtime.h>
#include <helper_cuda.h>
#include <iostream>

CUcontext check_current_ctx() {
    CUcontext context{0};
    unsigned int api_ver;
    checkCudaErrors(cuCtxGetCurrent(&context));
    fprintf(stdout, "current context=%p\n", context);
    checkCudaErrors( cuCtxGetApiVersion(context, &api_ver));
    fprintf(stdout, "current context api version = %d\n", api_ver);
    return context;
}

auto inital_runtime_context() {
    int current_device = 0;
    int device_count = 0;
    int devices_prohibited = 0;
    CUcontext current_ctx{0};
    cudaDeviceProp deviceProp;
    checkCudaErrors(cudaGetDeviceCount(&device_count));;
    if (device_count == 0) {
        fprintf(stderr, "CUDA error: no devices supporting CUDA.\n");
        exit(EXIT_FAILURE);
    }
    // Find the GPU which is selected by Vulkan
    while (current_device < device_count) {
        cudaGetDeviceProperties(&deviceProp, current_device);
        if ((deviceProp.computeMode != cudaComputeModeProhibited)) {
            checkCudaErrors(cudaSetDevice(current_device));
            checkCudaErrors(cudaGetDeviceProperties(&deviceProp, current_device));
            printf("GPU Device %d: \"%s\" with compute capability %d.%d\n\n", current_device, deviceProp.name, deviceProp.major, deviceProp.minor);
            CUcontext current_ctx;
            cuCtxGetCurrent(&current_ctx);
            std::cout << "current_ctx=" << current_ctx << "\n";
            return current_device;
        } else {
            devices_prohibited++;
        }
        current_device++;
    }
    if (devices_prohibited == device_count) {
        fprintf(stderr, "CUDA error:" " No Vulkan-CUDA Interop capable GPU found.\n");
        exit(EXIT_FAILURE);
    }
    return -1;
}

void test_runtime_driver_op() {
    inital_runtime_context();
    check_current_ctx();
}

运行后输出如下错误:

GPU Device 0: "GeForce RTX ..." with compute capability 7.5
current_ctx=0x6eb220
current context=0x6eb220
CUDA error at ... code=201(CUDA_ERROR_INVALID_CONTEXT) "cuCtxGetApiVersion(context, &api_ver)"

问题原因分析

这个问题的核心在于CUDA Runtime对主上下文的自动生命周期管理

当你在inital_runtime_context函数中调用cudaSetDevice时,Runtime会隐式初始化并创建对应设备的主上下文,同时将其绑定到当前线程。但当函数执行完毕返回后,当前线程没有任何活跃的Runtime API调用,Runtime会自动将这个主上下文从当前线程解绑(甚至在某些情况下销毁它)。

此时你调用cuCtxGetCurrent()拿到的只是之前上下文的内存地址,但这个上下文已经不再有效,因此后续调用cuCtxGetApiVersion就会抛出CUDA_ERROR_INVALID_CONTEXT错误。

而你之前测试的「先Driver API创建上下文,再用Runtime API」场景没问题,是因为Driver API创建的上下文生命周期由用户手动管理,Runtime不会自动清理它,所以可以正常复用。

解决方案

这里提供两种可行的解决思路:

1. 在调用Driver API前,重新绑定Runtime主上下文

通过调用一个简单的Runtime API(比如cudaSetDevicecudaGetDevice),让Runtime将主上下文重新绑定到当前线程,确保上下文处于活跃状态。

修改test_runtime_driver_op函数:

void test_runtime_driver_op() {
    int dev = inital_runtime_context();
    // 重新绑定主上下文到当前线程
    checkCudaErrors(cudaSetDevice(dev));
    check_current_ctx();
}

或者直接在check_current_ctx开头添加一行:

CUcontext check_current_ctx() {
    // 触发Runtime上下文绑定,确保上下文有效
    int current_dev;
    checkCudaErrors(cudaGetDevice(&current_dev));
    
    CUcontext context{0};
    unsigned int api_ver;
    checkCudaErrors(cuCtxGetCurrent(&context));
    fprintf(stdout, "current context=%p\n", context);
    checkCudaErrors( cuCtxGetApiVersion(context, &api_ver));
    fprintf(stdout, "current context api version = %d\n", api_ver);
    return context;
}

2. 显式增加上下文引用计数,防止Runtime销毁

使用Driver API的cuCtxRetain函数增加上下文的引用计数,这样Runtime就不会自动销毁这个上下文,直到你调用cuCtxRelease释放引用。

修改inital_runtime_context中获取上下文的部分:

// ... 原有代码 ...
printf("GPU Device %d: \"%s\" with compute capability %d.%d\n\n", current_device, deviceProp.name, deviceProp.major, deviceProp.minor);
CUcontext current_ctx;
checkCudaErrors(cuCtxGetCurrent(&current_ctx));
// 增加引用计数,阻止Runtime自动销毁上下文
checkCudaErrors(cuCtxRetain(&current_ctx));
std::cout << "current_ctx=" << current_ctx << "\n";
return current_device;
// ... 原有代码 ...

注意:使用这种方式后,记得在不需要上下文的时候调用cuCtxRelease(current_ctx)来释放引用,避免内存泄漏。


内容的提问来源于stack exchange,提问作者Wang

火山引擎 最新活动