You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

在TensorFlow自定义C++ Op中调用标准Op时Bazel编译报错求助

解决自定义TensorFlow C++ Op的依赖错误与预定义Op调用问题

嘿,我刚碰到过类似的问题,你的问题根源其实是在Op内核里误用了TensorFlow的客户端API,咱们一步步来拆解解决:

1. 先搞定Bazel的依赖错误

你看到的tensorflow/cc:cc_ops cannot depend on tensorflow/core:framework错误,本质是因为tensorflow/cc模块是给外部客户端(比如写C++/Python调用TF的业务代码)用的高层API,而Op内核属于TF执行引擎的核心组件,归tensorflow/core范畴,两者的依赖链条是互斥的,根本不能混着用。

要修复这个,你需要做两件事:

  • 清理代码里的客户端API:把#include "tensorflow/cc/client/client_session.h"#include "tensorflow/cc/ops/standard_ops.h"这些头文件删掉,还有ScopeClientSession这些客户端相关的代码也全部移除——Op内核里完全不能这么用。
  • 调整BUILD文件的依赖:把//tensorflow/cc相关的依赖删掉,只保留核心的内核依赖:
    load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
    
    tf_custom_op_library(
        name = "MyNewOp.so",
        srcs = ["mynewop.cc"],
        deps = [
            "//tensorflow/core:framework",
            "//tensorflow/core:tensorflow",
            "//tensorflow/core/kernels:matmul_op", # 加上MatMul内核的依赖
        ],
    )
    

2. 正确在自定义Op内核里调用预定义Op

当然可以调用预定义Op!但绝对不能用客户端那套构建图、跑Session的方式,得用TF内核层的底层API直接调用目标Op的Compute逻辑。

以你要调用的MatMul为例,直接调用它的内核计算函数就行,下面是调整后的完整核心代码:

REGISTER_OP("NewOp")
    .Input("input: int32")
    .Output("output: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/matmul_op.h" // 引入MatMul的内核实现头文件
#include "tensorflow/core/framework/op.h"

using namespace tensorflow;

class MyNewOp : public OpKernel {
 public:
  explicit MyNewOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // 1. 获取输入张量(保留你原来的逻辑)
    const Tensor& input_tensor = context->input(0);

    // 2. 准备MatMul的输入数据(这里用你原来的示例数据,可按需替换)
    Tensor A(DT_FLOAT, TensorShape({2, 2}));
    auto A_flat = A.flat<float>();
    A_flat(0) = 35.f; A_flat(1) = 22.f;
    A_flat(2) = -10.f; A_flat(3) = 0.f;

    Tensor b(DT_FLOAT, TensorShape({1, 2}));
    auto b_flat = b.flat<float>();
    b_flat(0) = 30.f; b_flat(1) = 55.f;

    // 3. 设置MatMul的参数(和你原来的TransposeB(true)对应)
    MatMulOpParams params;
    params.transpose_a = false;
    params.transpose_b = true;
    params.adjoint_a = false;
    params.adjoint_b = false;

    // 4. 申请MatMul的输出张量空间
    Tensor* matmul_output = nullptr;
    TensorShape output_shape = MatMulShape(A.shape(), b.shape(), params);
    OP_REQUIRES_OK(context, context->allocate_temp(DT_FLOAT, output_shape, &matmul_output));

    // 5. 直接调用MatMul的Compute逻辑
    OP_REQUIRES_OK(context, MatMulCompute<float>(context, A, b, *matmul_output, params));

    // 6. 处理结果并设置自定义Op的输出(示例逻辑,按需调整)
    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor));
    auto output_flat = output_tensor->flat<int32>();
    
    // 这里只是把MatMul的浮点结果转成int32,你可以换成自己的业务逻辑
    auto matmul_flat = matmul_output->flat<float>();
    for (int i = 0; i < output_flat.size(); ++i) {
      output_flat(i) = static_cast<int32>(matmul_flat(i));
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("NewOp").Device(DEVICE_CPU), MyNewOp);

几个关键提醒

  • Op内核是跑在TF的执行流里的,绝对不能创建新的ClientSession或者构建独立的图,所有操作都要基于当前的OpKernelContext来做。
  • 调用其他Op的Compute函数时,一定要确认输入张量的形状、类型和参数都符合目标Op的要求,不然很容易出现形状不匹配或者类型错误。
  • 如果要调用的Op有CPU/GPU多个实现,你可以通过OpKernelContext::CreateOpKernel来获取适配当前设备的内核实例,这样能自动适配不同设备。

内容的提问来源于stack exchange,提问作者Charles Tao

火山引擎 最新活动