You need to enable JavaScript to run this app.
优惠活动
大模型
产品
解决方案
定价
更多
文档控制台
免费开始使用

如何修改CRNN模型的CTC层以输出文字识别置信度分数?

给CRNN+CTC模型添加识别置信度的TensorFlow实现方案

我刚好做过类似的单字图像CRNN识别任务,结合你用的TensorFlow框架和那个CRNN实现,给你几个实用的修改思路,帮你同时输出识别文字和置信度分数:

核心思路:利用CTC解码的概率信息

CTC层的输出是每个时间步的类别概率分布,我们可以从解码过程中直接提取序列级的置信度,或者针对每个字符计算单独的置信度。下面是两种最容易落地的方案:

方案1:用Beam Search解码直接获取序列置信度

这是最简单高效的方式,因为tf.nn.ctc_beam_search_decoder会直接返回每个候选序列的对数似然概率,我们只需要把它转换成普通概率就能作为整个识别序列的置信度。

修改原代码的解码部分:

# 原CRNN模型输出logits,shape为 [batch_size, max_time, num_classes]
logits = crnn_model(inputs)
seq_len = tf.fill([tf.shape(logits)[0]], tf.shape(logits)[1])  # 假设每个样本的时间步一致

# 替换贪心解码为beam search,同时获取解码结果和对数概率
decoded, log_probs = tf.nn.ctc_beam_search_decoder(
    tf.transpose(logits, (1, 0, 2)),  # CTC要求输入格式为 [max_time, batch_size, num_classes]
    seq_len,
    beam_width=10  # beam宽度可以根据需求调整,越大准确率越高但速度稍慢
)

# 将对数概率转换为0-1之间的置信度
confidences = tf.exp(log_probs)

# 把稀疏解码结果转成可读文本(这里假设你有char_list字符映射表)
dense_decoded = tf.sparse.to_dense(decoded[0], default_value=-1)
texts = tf.map_fn(
    lambda x: tf.strings.reduce_join(tf.gather(char_list, tf.boolean_mask(x, x != -1))),
    dense_decoded,
    dtype=tf.string
)

这样你就能同时得到texts(识别结果)和confidences(每个结果的置信度),置信度越接近1,说明识别结果越可靠。

方案2:贪心解码下手动计算字符/序列置信度

如果你坚持用贪心解码,那需要手动从每个时间步的概率中提取对应字符的置信度:

  1. 先对CRNN的logits做softmax,得到每个时间步的类别概率
  2. 解码后,对齐每个识别字符对应的时间步区间(因为CTC会合并重复字符)
  3. 取每个字符对应时间步中的最大概率,再计算平均/乘积作为序列置信度

示例代码:

logits = crnn_model(inputs)
probs = tf.nn.softmax(logits, axis=-1)  # 得到每个时间步的类别概率,shape [batch, max_time, num_classes]
seq_len = tf.fill([tf.shape(logits)[0]], tf.shape(logits)[1])

# 贪心解码
decoded, _ = tf.nn.ctc_greedy_decoder(tf.transpose(logits, (1, 0, 2)), seq_len)
dense_decoded = tf.sparse.to_dense(decoded[0], default_value=-1)
blank_idx = num_classes - 1  # 假设空白符是最后一个类别

# 定义函数计算单个样本的置信度
def calc_single_confidence(sample_probs, sample_seq, seq_len):
    # 过滤掉空白符
    valid_chars = tf.boolean_mask(sample_seq, sample_seq != blank_idx)
    if tf.shape(valid_chars)[0] == 0:
        return 0.0
    
    # 遍历字符,找到每个字符对应的时间步区间
    conf_scores = []
    current_char = valid_chars[0]
    start_step = 0
    
    for i in range(1, tf.shape(valid_chars)[0]):
        if valid_chars[i] != current_char:
            # 取当前字符对应时间步的最大概率
            char_prob = tf.reduce_max(sample_probs[start_step:i, current_char])
            conf_scores.append(char_prob)
            current_char = valid_chars[i]
            start_step = i
    # 处理最后一个字符
    char_prob = tf.reduce_max(sample_probs[start_step:seq_len, current_char])
    conf_scores.append(char_prob)
    
    # 序列置信度取所有字符置信度的平均值
    return tf.reduce_mean(conf_scores)

# 批量计算置信度
confidences = tf.map_fn(
    lambda idx: calc_single_confidence(probs[idx], dense_decoded[idx], seq_len[idx]),
    tf.range(tf.shape(dense_decoded)[0]),
    dtype=tf.float32
)

落地注意事项

  • 优先选方案1,不仅实现简单,beam search的识别准确率通常也比贪心解码高
  • 置信度是0-1之间的数值,你可以设置阈值(比如0.8)过滤低置信度的识别结果
  • 确保字符映射表(char_list)和模型训练时的一致,避免文本转换错误

内容的提问来源于stack exchange,提问作者vinayak A

火山引擎 最新活动