You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

如何用TensorFlow获取细胞分类的混淆矩阵样本及分类图像列表

解决思路:获取细胞分类结果及各类样本列表

看起来你已经完成了相当扎实的细胞图像分类工作,要导出分类结果和各类样本列表用于人工检查,其实可以通过以下几个清晰的步骤来实现,全程用TensorFlow/Keras就能搞定:

步骤1:加载测试集数据与训练好的模型

首先得确保你能把测试集的图像路径真实标签和模型预测结果一一对应起来。如果之前用ImageDataGeneratorflow_from_directory加载测试集,那可以直接复用这个生成器,或者手动构建包含路径和标签的数据集:

import tensorflow as tf
from tensorflow.keras.models import load_model
import pandas as pd
import os

# 加载训练好的模型
model = load_model('your_trained_model.h5')  # 替换为你的模型保存路径

# 假设测试集按文件夹组织:test/live_cells, test/dead_cells
test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
test_generator = test_datagen.flow_from_directory(
    'path/to/test_set',
    target_size=(your_img_height, your_img_width),  # 替换为你的模型输入尺寸
    batch_size=32,
    class_mode='binary',  # 二分类任务:活/死细胞
    shuffle=False  # 关键!关闭打乱,保证文件名和预测结果顺序完全匹配
)

# 获取测试集的真实标签和图像路径(可转换为绝对路径更方便查看)
true_labels = test_generator.classes  # 0/1对应你定义的类别(比如0=死细胞,1=活细胞)
image_paths = [os.path.join(test_generator.directory, path) for path in test_generator.filenames]

如果你的测试集不是按文件夹组织,而是有单独的标签文件(比如CSV),可以用pd.read_csv加载路径和标签,再用自定义数据管道处理图像。

步骤2:获取模型预测结果

接下来对测试集进行预测,得到每个样本的预测类别和置信度:

# 获取预测概率(二分类下输出的是活细胞的概率)
pred_probs = model.predict(test_generator, verbose=1)
# 转换为预测类别(比如概率>0.5判定为活细胞,对应1;可根据需求调整阈值)
pred_labels = (pred_probs > 0.5).astype(int).flatten()

步骤3:分类筛选各类样本

现在有了真实标签、预测标签和图像路径,就可以筛选出你需要的所有列表:

1. 活/死细胞分类结果列表

# 模型分类为活细胞的图像路径
live_predicted = [path for path, pred in zip(image_paths, pred_labels) if pred == 1]
# 模型分类为死细胞的图像路径
dead_predicted = [path for path, pred in zip(image_paths, pred_labels) if pred == 0]

2. TP/FP/FN/TN样本列表

先明确二分类的定义(假设1=活细胞0=死细胞):

  • True Positives (TP):真实是活细胞,模型预测活细胞
  • False Positives (FP):真实是死细胞,模型预测活细胞
  • False Negatives (FN):真实是活细胞,模型预测死细胞
  • True Negatives (TN):真实是死细胞,模型预测死细胞

我们可以先把所有数据整理成DataFrame,方便后续筛选和查看:

results_df = pd.DataFrame({
    'image_path': image_paths,
    'true_label': true_labels,
    'pred_label': pred_labels,
    'pred_prob': pred_probs.flatten()
})

# 筛选各类样本路径
tp_samples = results_df[(results_df['true_label'] == 1) & (results_df['pred_label'] == 1)]['image_path'].tolist()
fp_samples = results_df[(results_df['true_label'] == 0) & (results_df['pred_label'] == 1)]['image_path'].tolist()
fn_samples = results_df[(results_df['true_label'] == 1) & (results_df['pred_label'] == 0)]['image_path'].tolist()
tn_samples = results_df[(results_df['true_label'] == 0) & (results_df['pred_label'] == 0)]['image_path'].tolist()

步骤4:导出列表到文件

最后把这些列表导出成CSV或TXT文件,方便人工检查:

# 导出完整结果到CSV(包含路径、真实标签、预测标签、置信度,方便全面排查)
results_df.to_csv('cell_classification_full_results.csv', index=False)

# 分别导出各类样本列表到TXT,方便快速定位
def save_list_to_txt(file_path, data_list):
    with open(file_path, 'w') as f:
        for item in data_list:
            f.write(f"{item}\n")

save_list_to_txt('live_predicted.txt', live_predicted)
save_list_to_txt('dead_predicted.txt', dead_predicted)
save_list_to_txt('tp_samples.txt', tp_samples)
save_list_to_txt('fp_samples.txt', fp_samples)
save_list_to_txt('fn_samples.txt', fn_samples)
save_list_to_txt('tn_samples.txt', tn_samples)

额外实用提示

  • 对于FN和FP样本,建议优先查看CSV里置信度接近阈值(比如0.4~0.6)的样本,这些往往是模型最容易混淆的案例,排查价值最高
  • 如果你的类别定义和示例相反(比如0=活细胞),只需要调整筛选条件里的标签值即可
  • 要是发现图像路径有错误,可以直接在CSV里修改路径,或者回溯标注阶段的数据源问题

内容的提问来源于stack exchange,提问作者chalbiophysics

火山引擎 最新活动