|
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility |
|
import towhee |
|
import os |
|
|
|
|
|
def create_milvus_collection(collection_name, dim):
    """(Re)create and index a Milvus collection for reverse image search.

    Connects to Milvus under the ``"default"`` alias using the
    ``milvus.host``, ``milvus.port``, ``milvus.user`` and ``milvus.password``
    environment variables. It then drops any existing collection called
    *collection_name* and creates a fresh one with two fields:

    * ``path``: VARCHAR primary key (max 100 chars), supplied by the caller
      (``auto_id=False``).
    * ``embedding``: FLOAT_VECTOR of length *dim*.

    Finally, it builds an IVF_FLAT index (L2 metric, ``nlist=2048``) on
    ``embedding``.

    Args:
        collection_name: Name of the collection to (re)create.
        dim: Dimensionality of the ``embedding`` vectors. It must match the
            length of the vectors inserted later.

    Returns:
        The newly created ``pymilvus.Collection``.
    """
    # NOTE: these variable names contain dots, so they cannot be set with a
    # plain POSIX-shell ``export``; use ``env`` or an env file instead.
    connections.connect(
        alias="default",
        host=os.getenv("milvus.host"),
        port=os.getenv("milvus.port"),
        user=os.getenv("milvus.user"),
        password=os.getenv("milvus.password"),
    )

    # Destructive by design: any existing data under this name is discarded,
    # so the schema below is always the one in effect.
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)

    fields = [
        # The image path doubles as the primary key. Paths longer than
        # max_length presumably get rejected by Milvus on insert (confirm).
        FieldSchema(
            name='path',
            dtype=DataType.VARCHAR,
            description='ids',
            max_length=100,
            is_primary=True,
            auto_id=False,
        ),
        FieldSchema(
            name='embedding',
            dtype=DataType.FLOAT_VECTOR,
            description='embedding vectors',
            dim=dim,
        ),
    ]
    schema = CollectionSchema(fields=fields, description='reverse image search')
    collection = Collection(name=collection_name, schema=schema)

    index_params = {
        'metric_type': 'L2',
        'index_type': "IVF_FLAT",
        'params': {"nlist": 2048},
    }
    collection.create_index(field_name="embedding", index_params=index_params)
    return collection
|
|
|
|
|
# Shared configuration. The CSV manifest is read twice below (by pandas and
# by towhee), and the collection's vector dim must equal the embedding size
# of the model (2048 for the resnet50 features used here). Keep these
# together so they cannot drift apart.
CSV_PATH = 'reverse_image_search.csv'
MODEL_NAME = 'resnet50'
EMBEDDING_DIM = 2048

collection = create_milvus_collection('reverse_image_search', EMBEDDING_DIM)

import pandas as pd
import cv2  # not used in this section; presumably needed later in the file
from towhee._types.image import Image  # not used in this section either

df = pd.read_csv(CSV_PATH)

# id -> image path lookup. Not used in this section; presumably used later
# to map search hits back to image files.
id_img = df.set_index('id')['path'].to_dict()

# Ingestion pipeline: for each CSV row, decode the image at 'path', embed
# it, normalize the vector and insert (path, vec) into the collection.
dc = (
    towhee.read_csv(CSV_PATH)
    .set_parallel(3)
    # Cast 'id' to int (presumably it is read from the CSV as text).
    .runas_op['id', 'id'](func=int)
    .image_decode['path', 'img']()
    .image_embedding.timm['img', 'vec'](model_name=MODEL_NAME)
    # Assuming tensor_normalize L2-normalizes, unit vectors make the L2
    # index metric rank results the same way cosine similarity would.
    .tensor_normalize['vec', 'vec']()
    .to_milvus['path', 'vec'](collection=collection, batch=100)
)
|
|