如何通过TensorFlow C++ API加载并运行自定义Op?
我刚好做过类似的事情,给你梳理下TensorFlow C++ API调用自定义Op的完整流程,和Python的逻辑是对应的,只是API细节不同:
1. 加载自定义Op动态库(对应Python的
tf.load_op_library) 在C++里,你需要用tensorflow::LoadLibrary函数来加载你的.so文件,它会把自定义Op注册到TensorFlow的全局Op注册表中,和Python的load_op_library作用完全一致。
首先要包含必要的头文件,然后编写加载代码:
#include "tensorflow/core/public/session.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/graph/graph_def_builder.h" int main() { // 加载自定义Op的.so库 tensorflow::Status load_status = tensorflow::LoadLibrary("/path/to/your/custom_op.so", nullptr); if (!load_status.ok()) { std::cerr << "加载自定义Op失败: " << load_status.ToString() << std::endl; return 1; } // 后续构建图和运行的代码...
2. 构建包含自定义Op的计算图(对应Python的
custom_mod.exec()) 加载完库后,你需要在计算图中添加这个自定义Op节点。这里要注意:Op的名称必须和你自定义Op注册时REGISTER_OP宏里的名称完全一致(比如你Python里调用的custom_mod.exec(),对应的注册名称应该是"Exec",具体要看你自定义Op的C++注册代码)。
推荐用Scope API来简洁构建图,示例如下:
// 创建会话选项和会话实例,指定GPU设备 tensorflow::SessionOptions session_options; session_options.config.mutable_gpu_options()->set_visible_device_list("0"); std::unique_ptr<tensorflow::Session> session(tensorflow::NewSession(session_options)); // 创建根Scope tensorflow::Scope root = tensorflow::Scope::NewRootScope(); // 添加自定义Op节点到图中 // 假设你的Op没有输入,有一个输出;如果有输入/属性,需要对应添加 auto custom_op = root.op("Exec") // 这里的"Exec"就是注册时的Op名称 .WithOpName("CustomExec") // 给这个节点起个名字,方便后续引用 .Output(0); // 获取Op的第一个输出 // 将Scope中的图转换为GraphDef tensorflow::GraphDef graph_def; tensorflow::Status graph_status = root.ToGraphDef(&graph_def); if (!graph_status.ok()) { std::cerr << "构建计算图失败: " << graph_status.ToString() << std::endl; return 1; } // 将GraphDef加载到会话中 tensorflow::Status create_status = session->Create(graph_def); if (!create_status.ok()) { std::cerr << "加载图到会话失败: " << create_status.ToString() << std::endl; return 1; }
3. 运行会话执行自定义Op(对应Python的
sess.run()) 最后一步就是运行会话,执行这个自定义Op节点,和Python的sess.run()逻辑一致:
// 准备存储输出的容器 std::vector<tensorflow::Tensor> outputs; // 运行Op:第一个参数是输入映射(无输入则为空),第二个是要获取的输出节点名称 tensorflow::Status run_status = session->Run( {}, // 输入张量映射,无输入则留空 {"CustomExec:0"}, // 要获取的输出节点(格式:节点名:输出索引) {}, // 要运行的节点列表(空则默认运行所有必要节点) &outputs ); if (!run_status.ok()) { std::cerr << "执行自定义Op失败: " << run_status.ToString() << std::endl; return 1; } // 如果Op有输出,可以在这里处理 if (!outputs.empty()) { std::cout << "自定义Op执行结果: " << outputs[0].DebugString() << std::endl; } // 关闭会话 session->Close(); return 0; }
关键注意事项
- 确保你的自定义Op
.so库是用与当前TensorFlow C++库版本一致的编译环境编译的,否则会出现链接或运行时错误。 - 如果你的Op需要输入张量,要在
session->Run()的第一个参数中传入对应的张量映射,比如{{"input_node_name", input_tensor}}。 - 如果Op有自定义属性(Attr),要在构建图时用
.Attr("attr_name", attr_value)来设置,和Python调用时的参数对应。
内容的提问来源于stack exchange,提问作者user3085931




