PyTorch中Siamese network的相似度/距离计算、分类方法咨询及代码问题排查
PyTorch中Siamese network的相似度/距离计算、分类方法咨询及代码问题排查
嗨,我来帮你理清楚Siamese网络的这些问题~ 先从你的代码问题说起,再一步步讲相似度计算和分类的方法。
一、你的代码里的小问题得先改
首先要敲黑板:你把ResNet的num_classes=5设置错啦!Triplet Loss是基于特征向量的距离来训练的,不是分类任务的logits输出。你现在forward_once返回的是5类的分类结果,这根本不是我们用来算距离的特征向量啊。
给你改个正确的版本:
import torch import torch.nn as nn import torchvision.models as models class SiameseNetwork(nn.Module): def __init__(self, feature_dim=256) -> None: super().__init__() # 加载ResNet18,这里可以根据需求加pretrained=True用预训练权重 self.resnet = models.resnet18(pretrained=False) # 拿到ResNet倒数第二层的输出维度(默认是512) in_features = self.resnet.fc.in_features # 把最后一层全连接替换成输出feature_dim维的特征层 self.resnet.fc = nn.Linear(in_features, feature_dim) def forward_once(self, item): output = self.resnet(item) # 加个L2归一化,让特征落在单位球面上,对Triplet Loss训练更友好 output = nn.functional.normalize(output, p=2, dim=1) return output def forward(self, anchor, positive, negative): output1 = self.forward_once(anchor) output2 = self.forward_once(positive) output3 = self.forward_once(negative) return output1, output2, output3
解释下:这样改了之后,forward_once返回的是归一化后的特征向量,完全适合Triplet Loss的训练和后续的距离计算。
二、相似度/距离怎么算?
训练的时候,TripletMarginLoss默认用的是欧氏距离(L2距离),你也可以通过参数指定用L1距离。到了推理阶段,常用的计算方式有这几种:
- 欧氏距离:数值越小,两个样本越像。直接用PyTorch的API就能算:
# feat1、feat2是两个样本的特征,shape都是[batch_size, feature_dim] euclidean_dist = torch.nn.functional.pairwise_distance(feat1, feat2, p=2) - 余弦相似度:取值在[-1,1]之间,越接近1说明两个样本越相似:
cos_sim = torch.nn.functional.cosine_similarity(feat1, feat2, dim=1) - 要是你想把距离转成0-1之间的相似度,直接用
1 / (1 + euclidean_dist)就行,数值越接近1,样本越相似。
三、训练好之后怎么分类?
分两种常见的任务场景说:
两样本匹配(比如判断是不是同一类)
训练完模型后,你可以用验证集找一个合适的距离阈值。比如计算两个样本的欧氏距离,如果小于这个阈值(比如0.5),就判定它们是同一类,否则不是。这个阈值得自己调,找验证集上效果最好的那个值就行。多分类任务(比如判断样本属于哪一类)
你可以先给每个类算一个“中心特征”:比如把每个类的所有样本都提取特征,然后取平均值作为这个类的中心。推理的时候,把待预测样本的特征和所有类的中心算距离,选距离最小的那个类作为预测结果。
给你个简单的代码片段参考:# 假设class_centers是个字典,key是类标签,value是该类的中心特征(shape: [feature_dim]) # feat是待预测样本的特征(shape: [feature_dim]) min_dist = float('inf') pred_label = -1 for label, center in class_centers.items(): dist = torch.norm(feat - center, p=2) # 计算欧氏距离 if dist < min_dist: min_dist = dist pred_label = label
额外的小提示
- 你用TripletMarginLoss的时候,输入顺序是(anchor, positive, negative),和你的forward输出对应,这部分是对的,不用改。
- 我在代码里加了特征的L2归一化,这会让Triplet Loss的训练更稳定,因为特征都在单位球面上,距离计算更直观。
- 如果你的任务是两两样本的匹配(不是三元组的任务),也可以考虑用ContrastiveLoss,不过Triplet Loss也完全能用,看你数据准备的情况。
备注:内容来源于stack exchange,提问作者Mitutoyo




