import pandas as pd
import time
from zipfile import ZipFile

with ZipFile('reverse_image_search.zip', 'r') as zip:
    # extract all files from the archive
    print('Extracting all the files now...')
    zip.extractall()
    print('Done!')

df = pd.read_csv('reverse_image_search.csv')
df.head()
import cv2
from towhee.types.image import Image

id_img = df.set_index('id')['path'].to_dict()

def read_images(results):
    # Load the matched images from disk, given Milvus search results.
    imgs = []
    for re in results:
        path = id_img[re.id]
        imgs.append(Image(cv2.imread(path), 'BGR'))
    return imgs
from milvus import default_server
from pymilvus import connections, utility

# Start the embedded Milvus server and connect to it.
default_server.start()
connections.connect(host='127.0.0.1', port=default_server.listen_port)
default_server.listen_port
time.sleep(20)
print(utility.get_server_version())
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility

def create_milvus_collection(collection_name, dim):
    connections.connect(host='127.0.0.1', port='19530')

    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)

    fields = [
        FieldSchema(name='id', dtype=DataType.INT64, description='ids', is_primary=True, auto_id=False),
        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='embedding vectors', dim=dim)
    ]
    schema = CollectionSchema(fields=fields, description='text image search')
    collection = Collection(name=collection_name, schema=schema)

    # Create an IVF_FLAT index on the embedding field.
    index_params = {
        'metric_type': 'L2',
        'index_type': 'IVF_FLAT',
        'params': {'nlist': 512}
    }
    collection.create_index(field_name='embedding', index_params=index_params)
    return collection

collection = create_milvus_collection('text_image_search', 512)
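# A small check, not part of the original: confirm the (still empty) collection exists.
# utility.has_collection and Collection.num_entities are standard pymilvus calls.
print(utility.has_collection('text_image_search'))  # expected: True
print(collection.num_entities)                      # expected: 0 before any inserts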
from towhee import ops, pipe, DataCollection
import numpy as np

# Note: this step needs teddy.png in the working directory; otherwise it will throw an error.
p = (
    pipe.input('path')
        .map('path', 'img', ops.image_decode.cv2('rgb'))
        .map('img', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='image'))
        .map('vec', 'vec', lambda x: x / np.linalg.norm(x))
        .output('img', 'vec')
)

DataCollection(p('./teddy.png')).show()
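# A minimal sketch, not part of the original: read the normalized embedding back out,
# assuming Towhee's DataQueue.get() returns one result row as [img, vec].
img, vec = p('./teddy.png').get()
print(vec.shape)  # expected: (512,) for clip_vit_base_patch16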
p2 = (
    pipe.input('text')
        .map('text', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='text'))
        .map('vec', 'vec', lambda x: x / np.linalg.norm(x))
        .output('text', 'vec')
)

DataCollection(p2("A teddybear on a skateboard in Times Square.")).show()

time.sleep(60)
# Re-create the collection so the index is empty before the bulk insert below.
collection = create_milvus_collection('text_image_search', 512)
def read_csv(csv_path, encoding='utf-8-sig'):
    import csv
    with open(csv_path, 'r', encoding=encoding) as f:
        data = csv.DictReader(f)
        for line in data:
            yield int(line['id']), line['path']

# Embed every image listed in the CSV and insert (id, vector) pairs into Milvus.
p3 = (
    pipe.input('csv_file')
        .flat_map('csv_file', ('id', 'path'), read_csv)
        .map('path', 'img', ops.image_decode.cv2('rgb'))
        .map('img', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='image'))
        .map('vec', 'vec', lambda x: x / np.linalg.norm(x))
        .map(('id', 'vec'), (), ops.ann_insert.milvus_client(host='127.0.0.1', port='19530', collection_name='text_image_search'))
        .output()
)

ret = p3('reverse_image_search.csv')
time.sleep(120)
collection.load()
time.sleep(120)
print('Total number of inserted entities is {}.'.format(collection.num_entities))
import pandas as pd
import cv2

def read_image(image_ids):
    df = pd.read_csv('reverse_image_search.csv')
    id_img = df.set_index('id')['path'].to_dict()
    imgs = []
    decode = ops.image_decode.cv2('rgb')
    for image_id in image_ids:
        path = id_img[image_id]
        imgs.append(decode(path))
    return imgs
p4 = (
    pipe.input('text')
        .map('text', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='text'))
        .map('vec', 'vec', lambda x: x / np.linalg.norm(x))
        .map('vec', 'result', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', collection_name='text_image_search', limit=5))
        .map('result', 'image_ids', lambda x: [item[0] for item in x])
        .map('image_ids', 'images', read_image)
        .output('text', 'images')
)

DataCollection(p4("A white dog")).show()
DataCollection(p4("A black dog")).show()
search_pipeline = (
    pipe.input('text')
        .map('text', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='text'))
        .map('vec', 'vec', lambda x: x / np.linalg.norm(x))
        .map('vec', 'result', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', collection_name='text_image_search', limit=5))
        .map('result', 'image_ids', lambda x: [item[0] for item in x])
        .output('image_ids')
)

def search(text):
    df = pd.read_csv('reverse_image_search.csv')
    id_img = df.set_index('id')['path'].to_dict()
    image_ids = search_pipeline(text).to_list()[0][0]
    return [id_img[image_id] for image_id in image_ids]
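# A quick sanity check, not part of the original: the helper should return five image
# paths for a text query once the collection has been loaded as above.
print(search("A white dog"))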
import gradio

interface = gradio.Interface(search,
                             gradio.inputs.Textbox(lines=1),
                             [gradio.outputs.Image(type="filepath", label=None) for _ in range(5)]
                             )
interface.launch()
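# Optional cleanup, not shown in the original: once the Gradio app is shut down,
# disconnect from Milvus and stop the embedded server (both are existing APIs in
# pymilvus and milvus-lite).
# connections.disconnect('default')
# default_server.stop()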