You need to enable JavaScript to run this app.
最新活动
大模型
产品
解决方案
定价
生态与合作
支持与服务
开发者
了解我们

PyTorch中Siamese network的相似度/距离计算、分类方法咨询及代码问题排查

PyTorch中Siamese network的相似度/距离计算、分类方法咨询及代码问题排查

嗨,我来帮你理清楚Siamese网络的这些问题~ 先从你的代码问题说起,再一步步讲相似度计算和分类的方法。

一、你的代码里的小问题得先改

首先要敲黑板:你把ResNet的num_classes=5设置错啦!Triplet Loss是基于特征向量的距离来训练的,不是分类任务的logits输出。你现在forward_once返回的是5类的分类结果,这根本不是我们用来算距离的特征向量啊。

给你改个正确的版本:

import torch
import torch.nn as nn
import torchvision.models as models

class SiameseNetwork(nn.Module):
    def __init__(self, feature_dim=256) -> None:
        super().__init__()
        # 加载ResNet18,这里可以根据需求加pretrained=True用预训练权重
        self.resnet = models.resnet18(pretrained=False)
        # 拿到ResNet倒数第二层的输出维度(默认是512)
        in_features = self.resnet.fc.in_features
        # 把最后一层全连接替换成输出feature_dim维的特征层
        self.resnet.fc = nn.Linear(in_features, feature_dim)

    def forward_once(self, item):
        output = self.resnet(item)
        # 加个L2归一化,让特征落在单位球面上,对Triplet Loss训练更友好
        output = nn.functional.normalize(output, p=2, dim=1)
        return output

    def forward(self, anchor, positive, negative):
        output1 = self.forward_once(anchor)
        output2 = self.forward_once(positive)
        output3 = self.forward_once(negative)
        return output1, output2, output3

解释下:这样改了之后,forward_once返回的是归一化后的特征向量,完全适合Triplet Loss的训练和后续的距离计算。

二、相似度/距离怎么算?

训练的时候,TripletMarginLoss默认用的是欧氏距离(L2距离),你也可以通过参数指定用L1距离。到了推理阶段,常用的计算方式有这几种:

  • 欧氏距离:数值越小,两个样本越像。直接用PyTorch的API就能算:
    # feat1、feat2是两个样本的特征,shape都是[batch_size, feature_dim]
    euclidean_dist = torch.nn.functional.pairwise_distance(feat1, feat2, p=2)
    
  • 余弦相似度:取值在[-1,1]之间,越接近1说明两个样本越相似:
    cos_sim = torch.nn.functional.cosine_similarity(feat1, feat2, dim=1)
    
  • 要是你想把距离转成0-1之间的相似度,直接用1 / (1 + euclidean_dist)就行,数值越接近1,样本越相似。

三、训练好之后怎么分类?

分两种常见的任务场景说:

  1. 两样本匹配(比如判断是不是同一类)
    训练完模型后,你可以用验证集找一个合适的距离阈值。比如计算两个样本的欧氏距离,如果小于这个阈值(比如0.5),就判定它们是同一类,否则不是。这个阈值得自己调,找验证集上效果最好的那个值就行。

  2. 多分类任务(比如判断样本属于哪一类)
    你可以先给每个类算一个“中心特征”:比如把每个类的所有样本都提取特征,然后取平均值作为这个类的中心。推理的时候,把待预测样本的特征和所有类的中心算距离,选距离最小的那个类作为预测结果。
    给你个简单的代码片段参考:

    # 假设class_centers是个字典,key是类标签,value是该类的中心特征(shape: [feature_dim])
    # feat是待预测样本的特征(shape: [feature_dim])
    min_dist = float('inf')
    pred_label = -1
    for label, center in class_centers.items():
        dist = torch.norm(feat - center, p=2)  # 计算欧氏距离
        if dist < min_dist:
            min_dist = dist
            pred_label = label
    

额外的小提示

  • 你用TripletMarginLoss的时候,输入顺序是(anchor, positive, negative),和你的forward输出对应,这部分是对的,不用改。
  • 我在代码里加了特征的L2归一化,这会让Triplet Loss的训练更稳定,因为特征都在单位球面上,距离计算更直观。
  • 如果你的任务是两两样本的匹配(不是三元组的任务),也可以考虑用ContrastiveLoss,不过Triplet Loss也完全能用,看你数据准备的情况。

备注:内容来源于stack exchange,提问作者Mitutoyo

火山引擎 最新活动