最近更新时间:2024.04.07 19:15:43
首次发布时间:2023.09.08 16:10:09
本文基于火山引擎云搜索服务 ES 和图文特征提取模型 CLIP,快速搭建一套以图搜图和以文搜图的图文检索应用。
图文检索在电商、广告、设计、搜索引擎等热门领域被广泛应用。常见的图文检索包括以图搜图和以文搜图,用户通过输入文字描述或上传图片就可以在海量的图片库中快速找到同款或者相似图片。
输入的文本描述和图片作为检索对象,分别对 image 和 text 进行特征提取,并在模型中对文本和图片建立相关联系,然后在海量图片数据库进行特征向量检索,返回与检索对象最相关的记录集合。其中特征提取部分采用 CLIP 模型,向量检索采用火山引擎云搜索服务在海量图片特征中进行快速搜索。
pip install -U sentence-transformers # 模型相关 pip install -U elasticsearch7==7.10.2 # ES 向量数据库相关 pip install -U pandas # 分析 splash 的 csv
本文选择使用 Unsplash 作为图片数据集。
def read_imgset(): path = '${下载的数据集所在路径}' documents = ['photos', 'keywords', 'collections', 'conversions', 'colors'] datasets = {} for doc in documents: files = glob.glob(path + doc + ".tsv*") subsets = [] for filename in files: # pd 分析csv df = pd.read_csv(filename, sep='\t', header=0) subsets.append(df) datasets[doc] = pd.concat(subsets, axis=0, ignore_index=True) return datasets
本文选取clip-ViT-B-32
作为以图搜图、以文搜图的模型。clip-ViT-B-32
能将图片和文字联系在一起,得到一个能同时表达图片和文字的模型。
在 ES 实例中创建一个索引(image_search),并为其配置 mappings 和 settings。
示例代码如下:
PUT image_search { "mappings": { "dynamic": "false", "properties": { "photo_id": { "type": "keyword" }, "photo_url": { "type": "keyword" }, "describe": { "type": "text" }, "photo_embedding": { "type": "knn_vector", "dimension": 512 } } }, "settings": { "index": { "refresh_interval": "60s", "number_of_shards": "3", "knn.space_type": "cosinesimil", "knn": "true", "number_of_replicas": "1" } } }
当准备好数据集、模型和索引后,您可以连接 ES 实例并将数据集 CSV 文件写入目标索引。
在 ES 实例详情页面,获取实例访问地址。
如果需要在公网环境访问 ES 实例,请提前为实例开启公网访问。相关文档,请参见开启实例公网访问。
连接实例。
# 连接云搜索实例。如果遗忘实例访问用户(admin)的密码,可以选择重置密码。 cloudSearch = CloudSearch("https://{user}:{password}@{ES_URL}", verify_certs=False, ssl_show_warn=False)
写入数据到目标索引(image_search)。
from sentence_transformers import SentenceTransformer from elasticsearch7 import Elasticsearch as CloudSearch from PIL import Image import requests import pandas as pd import glob from os.path import join # We use the original clip-ViT-B-32 for encoding images img_model = SentenceTransformer('clip-ViT-B-32') text_model = SentenceTransformer('clip-ViT-B-32-multilingual-v1') # Construct request for es def encodedataset(photo_id, photo_url, describe, image): encoded_sents = { "photo_id": photo_id, "photo_url": photo_url, "describe": describe, "photo_embedding": img_model.encode(image), } return encoded_sents # download images def load_image(url_or_path): if url_or_path.startswith("http://") or url_or_path.startswith("https://"): return Image.open(requests.get(url_or_path, stream=True).raw) else: return Image.open(url_or_path) # 从unsplash的csv文件解出图片url,然后下载图片。 # 下载完了后用model 生成embedding,并构造成ES的请求进行写入。 def get_imgset_and_bulk(): datasets = read_imgset() datasets['photos'].head() kwywords = datasets['keywords'] docs = [] #遍历CSV, 根据photo_url 去download photo。 for idx, row in datasets['photos'].iterrows(): print("Process id: ", idx) # 获取CSV 中的url。 photo_url = row["photo_image_url"] photo_id = row["photo_id"] image = load_image(photo_url) # 找到photo_id 且 suggested true 对应的图片描述。 filter = kwywords.loc[(kwywords['photo_id'] == photo_id) & (kwywords['suggested_by_user'] == 't')] text = ' '.join(set(filter['keyword'])) # 封装写入ES的请求。 one_document = encodedataset(photo_id=photo_id, photo_url=photo_url, describe=text, image=image) docs.append({"index": {}}) docs.append(one_document) if idx % 20 == 0: # 20条一组进行写入。 resp = cloudSearch.bulk(docs, index='image_search') print(resp) docs = [] return docs if __name__ == '__main__': docs = get_imgset_and_bulk() print(docs)
图片向量化,执行 knn 查询。
def extract(img): # 以图搜图 res = cloudSearch.search( body={ "size": 5, "query": {"knn": {"photo_embedding": {"vector": img_model.encode(img), "k": 5}}}, "_source": ["describe", "photo_url"], }, index="image_search", ) return res fe = FeatureExtractor() @app.route('/', methods=['GET', 'POST']) def index(): # ... # Save query image img = Image.open(file.stream) # PIL image uploaded_img_path = "static/uploaded/" + datetime.now().isoformat().replace(":", ".") + "_" + file.filename img.save(uploaded_img_path) # Run search resp = fe.extract(img) return render_template('index.html', query_path=uploaded_img_path, scores=resp['hits']['hits']) # ...
搜索一张“海豹”图片,查看返回结果。