You need to enable JavaScript to run this app.
导航
基于云搜索服务构建图文检索应用(以图搜图、以文搜图)
最近更新时间:2024.07.04 10:22:15首次发布时间:2023.09.08 16:10:09

本文基于火山引擎云搜索服务和图文特征提取模型 CLIP,快速搭建一套以图搜图和以文搜图的图文检索应用。

背景信息

图文检索在电商、广告、设计、搜索引擎等热门领域被广泛应用。常见的图文检索包括以图搜图和以文搜图,用户通过输入文字描述或上传图片就可以在海量的图片库中快速找到同款或者相似图片。
输入的文本描述和图片作为检索对象,分别对 image 和 text 进行特征提取,并在模型中对文本和图片建立相关联系,然后在海量图片数据库进行特征向量检索,返回与检索对象最相关的记录集合。其中特征提取部分采用 CLIP 模型,向量检索采用火山引擎云搜索服务在海量图片特征中进行快速搜索。
图片

步骤一:准备环境

  1. 登录云搜索服务控制台,然后创建一个 7.10 版本的 ES 实例。
    图片
  2. 安装 Python Client 依赖。
    pip install -U sentence-transformers # 模型相关
    pip install -U elasticsearch7==7.10.2 # ES 向量数据库相关
    pip install -U pandas # 分析 splash 的 csv
    

步骤二:准备数据集

本文选择使用 Unsplash 作为图片数据集。

  1. 登录Unsplash,并下载免费的 Lite 数据集。
    Lite 数据集包含约 25000 张照片。下载完成后会获得一个压缩文件,其中包含描述图片的 CSV 文件。
    图片
  2. 使用 Pandas 读取 CSV 文件,获得图片的 URL 地址。
    def read_imgset():
        path = '${下载的数据集所在路径}'
        documents = ['photos', 'keywords', 'collections', 'conversions', 'colors']
        datasets = {}
        
        for doc in documents:
            files = glob.glob(path + doc + ".tsv*")
            subsets = []
            for filename in files:
                # pd 分析csv
                df = pd.read_csv(filename, sep='\t', header=0)
                subsets.append(df)    
            datasets[doc] = pd.concat(subsets, axis=0, ignore_index=True)
        return datasets
    

步骤三:选择 CLIP 模型

本文选取clip-ViT-B-32作为以图搜图、以文搜图的模型。
clip-ViT-B-32能将图片和文字联系在一起,得到一个能同时表达图片和文字的模型。

步骤四:创建 ES 索引

在 ES 实例中创建一个索引(image_search),并为其配置 mappings 和 settings。
示例代码如下:

PUT image_search
{
  "mappings": {
    "dynamic": "false",
    "properties": {
      "photo_id": { "type": "keyword" },
      "photo_url": { "type": "keyword" },
      "describe": { "type": "text" },
      "photo_embedding": { "type": "knn_vector", "dimension": 512 }
    }
  },
  "settings": {
    "index": {
      "refresh_interval": "60s",
      "number_of_shards": "3",
      "knn.space_type": "cosinesimil",
      "knn": "true",
      "number_of_replicas": "1"
    }
  }
}

步骤五:写入数据

当准备好数据集、模型和索引后,您可以连接 ES 实例并将数据集 CSV 文件写入目标索引。

  1. 在 ES 实例详情页面,获取实例访问地址。
    如果需要在公网环境访问 ES 实例,请提前为实例开启公网访问。相关文档,请参见开启实例公网访问
    图片

  2. 连接实例。

    # 连接云搜索实例。如果遗忘实例访问用户(admin)的密码,可以选择重置密码。
    cloudSearch = CloudSearch("https://{user}:{password}@{ES_URL}", 
                        verify_certs=False, 
                        ssl_show_warn=False)
    
  3. 写入数据到目标索引(image_search)。

    from sentence_transformers import SentenceTransformer
    from elasticsearch7 import Elasticsearch as CloudSearch
    from PIL import Image
    import requests
    import pandas as pd
    import glob
    from os.path import join 
    
    # We use the original clip-ViT-B-32 for encoding images
    img_model = SentenceTransformer('clip-ViT-B-32')
    text_model = SentenceTransformer('clip-ViT-B-32-multilingual-v1')
    
    # Construct request for es
    def encodedataset(photo_id, photo_url, describe, image):
        encoded_sents = {
            "photo_id": photo_id,
            "photo_url": photo_url,
            "describe": describe,
            "photo_embedding": img_model.encode(image),
        }
        return encoded_sents
    
    # download images
    def load_image(url_or_path):
        if url_or_path.startswith("http://") or url_or_path.startswith("https://"):
            return Image.open(requests.get(url_or_path, stream=True).raw)
        else:
            return Image.open(url_or_path)
    
    # 从unsplash的csv文件解出图片url,然后下载图片。
    # 下载完了后用model 生成embedding,并构造成ES的请求进行写入。
    def get_imgset_and_bulk():
        datasets = read_imgset()
        datasets['photos'].head()
        kwywords = datasets['keywords']
        docs = []
        #遍历CSV, 根据photo_url 去download photo。
        for idx, row in datasets['photos'].iterrows():
            print("Process id: ", idx)
            # 获取CSV 中的url。
            photo_url = row["photo_image_url"]
            photo_id  = row["photo_id"]
            image = load_image(photo_url)
            # 找到photo_id 且 suggested true 对应的图片描述。
            filter = kwywords.loc[(kwywords['photo_id'] == photo_id) & (kwywords['suggested_by_user'] == 't')]
            text = ' '.join(set(filter['keyword']))
            # 封装写入ES的请求。
            one_document = encodedataset(photo_id=photo_id, photo_url=photo_url, describe=text, image=image)
            docs.append({"index": {}})
            docs.append(one_document)
            if idx % 20 == 0:
                # 20条一组进行写入。
                resp = cloudSearch.bulk(docs, index='image_search')
                print(resp)
                docs = []
        return docs
        
    if __name__ == '__main__':
        docs = get_imgset_and_bulk()
        print(docs)
    

结果验证

图片向量化,执行 knn 查询。

def extract(img):
    # 以图搜图
    res = cloudSearch.search(
        body={
            "size": 5,
            "query": {"knn": {"photo_embedding": {"vector": img_model.encode(img), "k": 5}}},
            "_source": ["describe", "photo_url"],
        },
        index="image_search",
    )
    return res
    
fe = FeatureExtractor()
@app.route('/', methods=['GET', 'POST'])
def index():
    # ...
    # Save query image
    img = Image.open(file.stream)  # PIL image
    uploaded_img_path = "static/uploaded/" + datetime.now().isoformat().replace(":", ".") + "_" + file.filename
    img.save(uploaded_img_path)

    # Run search
    resp = fe.extract(img)

    return render_template('index.html',
                       query_path=uploaded_img_path,
                       scores=resp['hits']['hits'])
    # ...
    
    

搜索一张“海豹”图片,查看返回结果。