You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

CUDA中静态指针传函数参数及全局函数调用设备函数的问题

在CUDA中通过静态函数指针传递设备函数到全局函数的实现

看起来你已经有了一个不错的开头,这里我会帮你完善代码,并且解释清楚在CUDA中如何正确地将设备函数作为参数传递给全局函数(核函数)。

首先,CUDA里的设备函数指针和主机端的函数指针有本质区别:主机端不能直接获取设备函数的地址,必须通过CUDA runtime API来获取设备端的有效地址。下面是完整的可运行实现:

#include <stdio.h>
#include <cuda.h>
#include <math.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include "device_launch_parameters.h"

// 设备函数定义
__device__ void foo(int i) {
    printf("Hello from foo! Thread ID: %d\n", i);
}

__device__ void bar(int i) {
    printf("Hello from bar! Thread ID: %d\n", i);
}

// 定义设备函数指针类型,签名必须和目标设备函数完全匹配
typedef void(*FunctionCallback)(int);

// 设备端的函数调用器(可选,用来封装调用逻辑)
__device__ void callFunction(FunctionCallback funcx, int val) {
    funcx(val);
}

// 全局函数(核函数),接收设备函数指针作为参数
__global__ void kernel(FunctionCallback func) {
    int threadId = threadIdx.x;
    // 调用传入的设备函数,传递当前线程ID
    callFunction(func, threadId);
}

int main() {
    // 存储设备函数的指针(设备端地址)
    FunctionCallback d_foo_ptr, d_bar_ptr;

    // 通过cudaGetSymbolAddress获取设备函数的有效地址
    cudaGetSymbolAddress((void**)&d_foo_ptr, foo);
    cudaGetSymbolAddress((void**)&d_bar_ptr, bar);

    // 启动核函数,传入foo的设备函数指针
    printf("Running kernel with foo:\n");
    kernel<<<1, 3>>>(d_foo_ptr);
    cudaDeviceSynchronize(); // 等待核函数执行完成

    // 启动核函数,传入bar的设备函数指针
    printf("\nRunning kernel with bar:\n");
    kernel<<<1, 3>>>(d_bar_ptr);
    cudaDeviceSynchronize();

    // 检查CUDA运行时错误
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        fprintf(stderr, "CUDA Error: %s\n", cudaGetErrorString(err));
        return 1;
    }

    return 0;
}

关键细节解释:

  • 函数指针类型匹配FunctionCallback的签名必须和foobar完全一致(返回值、参数类型/数量),否则会导致未定义行为。
  • 获取设备函数地址:主机端不能直接用&foo获取设备函数地址(这会得到主机端视角的无效地址),必须用cudaGetSymbolAddress来获取设备端的有效符号地址,因为设备函数是存储在设备全局内存中的符号。
  • 核函数传递指针:核函数运行在设备端,所以可以直接接收设备函数指针并调用,不需要额外的内存拷贝。
  • 同步与错误检查cudaDeviceSynchronize()确保核函数执行完成后再打印结果,避免输出混乱;cudaGetLastError()用来捕获核函数启动或执行中的错误。

额外注意事项:

  • 如果是在设备函数内部传递函数指针(比如你代码里的call1函数),直接使用FunctionCallback类型即可,不需要通过cudaGetSymbolAddress,因为已经在设备端上下文里了。
  • 绝对不能把主机函数的指针传递给核函数,设备端无法访问主机端的代码空间,会触发非法访问错误。
  • 现代CUDA设备(SM 2.0及以上)都支持函数指针,如果你的设备太老可能需要升级或调整编译选项。

内容的提问来源于stack exchange,提问作者Dao Thanh

火山引擎 最新活动