from milvus import default_server
from pymilvus import connections, utility, FieldSchema, CollectionSchema, DataType, Collection

# Start the embedded Milvus server (listens on 19530 by default, matching PORT below)
default_server.start()

import cv2
import numpy
from glob import glob
from pathlib import Path
from statistics import mean

from towhee import pipe, ops, DataCollection
from towhee.types.image import Image

# Towhee parameters
MODEL = 'vgg16'
DEVICE = None  # if None, use the default device (cuda is used if available)

# Milvus parameters
HOST = '127.0.0.1'
PORT = '19530'
TOPK = 10
DIM = 512  # dimension of the extracted embedding; change together with MODEL
COLLECTION_NAME = 'deep_dive_image_search_' + MODEL
INDEX_TYPE = 'IVF_FLAT'
METRIC_TYPE = 'L2'

# Patterns of image paths
INSERT_SRC = './train/*/*.JPEG'
QUERY_SRC = './test/*/*.JPEG'

to_insert = glob(INSERT_SRC)
to_test = glob(QUERY_SRC)


# Create a Milvus collection (drop it first if it already exists)
def create_milvus_collection(collection_name, dim):
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)

    fields = [
        FieldSchema(name='path', dtype=DataType.VARCHAR, description='path to image',
                    max_length=500, is_primary=True, auto_id=False),
        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR,
                    description='image embedding vectors', dim=dim)
    ]
    schema = CollectionSchema(fields=fields, description='reverse image search')
    collection = Collection(name=collection_name, schema=schema)

    index_params = {
        'metric_type': METRIC_TYPE,
        'index_type': INDEX_TYPE,
        'params': {"nlist": 2048}
    }
    collection.create_index(field_name='embedding', index_params=index_params)
    return collection


# Read images from a list of paths
decoder = ops.image_decode('rgb').get_op()

def read_images(img_paths):
    imgs = []
    for p in img_paths:
        img = decoder(p)
        imgs.append(img)
        # imgs.append(Image(cv2.imread(p), 'RGB'))
    return imgs


# Get ground truth: all train images in the class folder matching the test image
def ground_truth(path):
    train_path = str(Path(path).parent).replace('test', 'train')
    return [str(Path(x).resolve()) for x in glob(train_path + '/*.JPEG')]


# Calculate Average Precision (AP) of the predictions against the ground truth
def get_ap(pred: list, gt: list):
    ct = 0
    score = 0.
    for i, n in enumerate(pred):
        if n in gt:
            ct += 1
            score += (ct / (i + 1))
    if ct == 0:
        ap = 0
    else:
        ap = score / ct
    return ap
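# A quick sanity check for get_ap with toy values (assumption: this check is
# illustrative and not part of the original tutorial). With relevant hits at
# ranks 1, 2, and 4, AP = (1/1 + 2/2 + 3/4) / 3 ≈ 0.9167.
assert abs(get_ap(['a', 'b', 'x', 'c'], ['a', 'b', 'c']) - 2.75 / 3) < 1e-9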
# Embedding pipeline: decode the image, extract an embedding, and L2-normalize it
p_embed = (
    pipe.input('img_path')
        .map('img_path', 'img', ops.image_decode('rgb'))
        .map('img', 'vec', ops.image_embedding.timm(model_name=MODEL, device=DEVICE))
        .map('vec', 'vec', lambda x: x / numpy.linalg.norm(x, axis=0))
)

# Display an embedding result; not needed for the actual pipeline
p_display = p_embed.output('img_path', 'img', 'vec')
DataCollection(p_display(to_insert[0])).show()

# Connect to the Milvus service
connections.connect(host=HOST, port=PORT)

# Create the collection
collection = create_milvus_collection(COLLECTION_NAME, DIM)
print(f'A new collection created: {COLLECTION_NAME}')

# Insert data
p_insert = (
    p_embed.map(('img_path', 'vec'), 'mr', ops.ann_insert.milvus_client(
                    host=HOST,
                    port=PORT,
                    collection_name=COLLECTION_NAME))
        .output('mr')
)

for img_path in to_insert:
    p_insert(img_path)
print('Number of data inserted:', collection.num_entities)

# Performance evaluation
collection.load()

p_search_pre = (
    p_embed.map('vec', 'search_res', ops.ann_search.milvus_client(
                    host=HOST, port=PORT, limit=TOPK,
                    collection_name=COLLECTION_NAME))
        .map('search_res', 'pred', lambda x: [str(Path(y[0]).resolve()) for y in x])
    # .output('img_path', 'pred')
)

p_eval = (
    p_search_pre.map('img_path', 'gt', ground_truth)
        .map(('pred', 'gt'), 'ap', get_ap)
        .output('ap')
)

res = []
for img_path in to_test:
    ap = p_eval(img_path).get()[0]
    res.append(ap)

mAP = mean(res)
print(f'mAP@{TOPK}: {mAP}')

p_search_img = (
    p_search_pre.map('img_path', 'gt', ground_truth)
        .map(('pred', 'gt'), 'ap', get_ap)
        .map('pred', 'res', read_images)
        .output('img_path', 'img', 'res', 'ap')
)
DataCollection(p_search_img('./test/Joe_Biden/Biden11.JPEG')).show()


# Crop the largest detected object out of the image; fall back to the full image
def get_max_object(img, boxes):
    if len(boxes) == 0:
        return img
    max_area = 0
    max_img = img
    for box in boxes:
        x1, y1, x2, y2 = box
        area = (x2 - x1) * (y2 - y1)
        if area > max_area:
            max_area = area
            max_img = img[y1:y2, x1:x2, :]
    return max_img


p_yolo = (
    pipe.input('img_path')
        .map('img_path', 'img', ops.image_decode('rgb'))
        .map('img', ('boxes', 'class', 'score'), ops.object_detection.yolov5())
        .map(('img', 'boxes'), 'object', get_max_object)
)

# Display a detection result; not needed for the actual pipeline
p_display = (
    p_yolo.output('img', 'object')
)
DataCollection(p_display('./test/Joe_Biden/Biden11.JPEG')).show()

# Search with the largest detected object as the query
p_search_pre_yolo = (
    p_yolo.map('object', 'vec', ops.image_embedding.timm(model_name=MODEL, device=DEVICE))
        .map('vec', 'vec', lambda x: x / numpy.linalg.norm(x, axis=0))
        .map('vec', 'search_res', ops.ann_search.milvus_client(
                    host=HOST, port=PORT, limit=TOPK,
                    collection_name=COLLECTION_NAME))
        .map('search_res', 'pred', lambda x: [str(Path(y[0]).resolve()) for y in x])
    # .output('img_path', 'pred')
)

# Evaluate with AP
p_search_img_yolo = (
    p_search_pre_yolo.map('img_path', 'gt', ground_truth)
        .map(('pred', 'gt'), 'ap', get_ap)
        .map('pred', 'res', read_images)
        .output('img', 'object', 'res', 'ap')
)
DataCollection(p_search_img_yolo('./test/Joe_Biden/Biden11.JPEG')).show()
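# Optional sketch (assumption: not part of the original tutorial output):
# measure mAP of the object-detection pipeline the same way as above,
# reusing p_search_pre_yolo and the existing evaluation helpers.
p_eval_yolo = (
    p_search_pre_yolo.map('img_path', 'gt', ground_truth)
        .map(('pred', 'gt'), 'ap', get_ap)
        .output('ap')
)
res_yolo = [p_eval_yolo(img_path).get()[0] for img_path in to_test]
print(f'mAP@{TOPK} with object detection: {mean(res_yolo)}')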
# Gradio demo: search the collection with an uploaded image
import gradio

DEMO_MODEL = 'vgg16'
DEMO_COLLECTION = 'deep_dive_image_search_' + DEMO_MODEL


def f_search(img):
    p_search = (
        pipe.input('img')
            .map('img', 'vec', ops.image_embedding.timm(model_name=DEMO_MODEL, device=DEVICE))
            .map('vec', 'vec', lambda x: x / numpy.linalg.norm(x, axis=0))
            .map('vec', 'search_res', ops.ann_search.milvus_client(
                        host=HOST, port=PORT, limit=TOPK,
                        collection_name=DEMO_COLLECTION))
            .map('search_res', 'pred', lambda x: [str(Path(y[0]).resolve()) for y in x])
            .output('pred')
    )
    return p_search(img).get()[0]


interface = gradio.Interface(
    f_search,
    gradio.inputs.Image(type="pil", source='upload'),
    [gradio.outputs.Image(type="filepath", label=None) for _ in range(TOPK)]
)
interface.launch()
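# Cleanup sketch (assumption: runs after the Gradio demo is shut down; not part
# of the original tutorial). Release the loaded collection and stop the
# embedded Milvus server started at the top of the script.
collection.release()
default_server.stop()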