如何用TensorFlow获取细胞分类的混淆矩阵样本及分类图像列表
解决思路:获取细胞分类结果及各类样本列表
看起来你已经完成了相当扎实的细胞图像分类工作,要导出分类结果和各类样本列表用于人工检查,其实可以通过以下几个清晰的步骤来实现,全程用TensorFlow/Keras就能搞定:
步骤1:加载测试集数据与训练好的模型
首先得确保你能把测试集的图像路径、真实标签和模型预测结果一一对应起来。如果之前用ImageDataGenerator的flow_from_directory加载测试集,那可以直接复用这个生成器,或者手动构建包含路径和标签的数据集:
import tensorflow as tf from tensorflow.keras.models import load_model import pandas as pd import os # 加载训练好的模型 model = load_model('your_trained_model.h5') # 替换为你的模型保存路径 # 假设测试集按文件夹组织:test/live_cells, test/dead_cells test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255) test_generator = test_datagen.flow_from_directory( 'path/to/test_set', target_size=(your_img_height, your_img_width), # 替换为你的模型输入尺寸 batch_size=32, class_mode='binary', # 二分类任务:活/死细胞 shuffle=False # 关键!关闭打乱,保证文件名和预测结果顺序完全匹配 ) # 获取测试集的真实标签和图像路径(可转换为绝对路径更方便查看) true_labels = test_generator.classes # 0/1对应你定义的类别(比如0=死细胞,1=活细胞) image_paths = [os.path.join(test_generator.directory, path) for path in test_generator.filenames]
如果你的测试集不是按文件夹组织,而是有单独的标签文件(比如CSV),可以用pd.read_csv加载路径和标签,再用自定义数据管道处理图像。
步骤2:获取模型预测结果
接下来对测试集进行预测,得到每个样本的预测类别和置信度:
# 获取预测概率(二分类下输出的是活细胞的概率) pred_probs = model.predict(test_generator, verbose=1) # 转换为预测类别(比如概率>0.5判定为活细胞,对应1;可根据需求调整阈值) pred_labels = (pred_probs > 0.5).astype(int).flatten()
步骤3:分类筛选各类样本
现在有了真实标签、预测标签和图像路径,就可以筛选出你需要的所有列表:
1. 活/死细胞分类结果列表
# 模型分类为活细胞的图像路径 live_predicted = [path for path, pred in zip(image_paths, pred_labels) if pred == 1] # 模型分类为死细胞的图像路径 dead_predicted = [path for path, pred in zip(image_paths, pred_labels) if pred == 0]
2. TP/FP/FN/TN样本列表
先明确二分类的定义(假设1=活细胞,0=死细胞):
- True Positives (TP):真实是活细胞,模型预测活细胞
- False Positives (FP):真实是死细胞,模型预测活细胞
- False Negatives (FN):真实是活细胞,模型预测死细胞
- True Negatives (TN):真实是死细胞,模型预测死细胞
我们可以先把所有数据整理成DataFrame,方便后续筛选和查看:
results_df = pd.DataFrame({ 'image_path': image_paths, 'true_label': true_labels, 'pred_label': pred_labels, 'pred_prob': pred_probs.flatten() }) # 筛选各类样本路径 tp_samples = results_df[(results_df['true_label'] == 1) & (results_df['pred_label'] == 1)]['image_path'].tolist() fp_samples = results_df[(results_df['true_label'] == 0) & (results_df['pred_label'] == 1)]['image_path'].tolist() fn_samples = results_df[(results_df['true_label'] == 1) & (results_df['pred_label'] == 0)]['image_path'].tolist() tn_samples = results_df[(results_df['true_label'] == 0) & (results_df['pred_label'] == 0)]['image_path'].tolist()
步骤4:导出列表到文件
最后把这些列表导出成CSV或TXT文件,方便人工检查:
# 导出完整结果到CSV(包含路径、真实标签、预测标签、置信度,方便全面排查) results_df.to_csv('cell_classification_full_results.csv', index=False) # 分别导出各类样本列表到TXT,方便快速定位 def save_list_to_txt(file_path, data_list): with open(file_path, 'w') as f: for item in data_list: f.write(f"{item}\n") save_list_to_txt('live_predicted.txt', live_predicted) save_list_to_txt('dead_predicted.txt', dead_predicted) save_list_to_txt('tp_samples.txt', tp_samples) save_list_to_txt('fp_samples.txt', fp_samples) save_list_to_txt('fn_samples.txt', fn_samples) save_list_to_txt('tn_samples.txt', tn_samples)
额外实用提示
- 对于FN和FP样本,建议优先查看CSV里置信度接近阈值(比如0.4~0.6)的样本,这些往往是模型最容易混淆的案例,排查价值最高
- 如果你的类别定义和示例相反(比如0=活细胞),只需要调整筛选条件里的标签值即可
- 要是发现图像路径有错误,可以直接在CSV里修改路径,或者回溯标注阶段的数据源问题
内容的提问来源于stack exchange,提问作者chalbiophysics




