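# Text-to-image search demo: embed images and text queries with CLIP (via Towhee),
# store the image vectors in Milvus, and serve a Gradio interface for text queries.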
import pandas as pd
import time
from zipfile import ZipFile

# Extract the bundled dataset (images plus a CSV manifest of id/path pairs).
with ZipFile('reverse_image_search.zip', 'r') as zip_file:
    print('Extracting all the files now...')
    zip_file.extractall()
    print('Done!')

df = pd.read_csv('reverse_image_search.csv')
print(df.head())
import cv2
from towhee.types.image import Image

# Map image ids to file paths so search results can be resolved back to images.
id_img = df.set_index('id')['path'].to_dict()

def read_images(results):
    imgs = []
    for res in results:
        path = id_img[res.id]
        imgs.append(Image(cv2.imread(path), 'BGR'))
    return imgs
time.sleep(60)

from milvus import default_server
from pymilvus import connections, utility

# Start the embedded Milvus server and give it time to become ready before connecting.
default_server.start()
time.sleep(60)
connections.connect(host='127.0.0.1', port=default_server.listen_port)
time.sleep(60)
print(default_server.listen_port)
time.sleep(20)
print(utility.get_server_version())
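# Create (or recreate) a Milvus collection with an INT64 primary-key field and a
# FLOAT_VECTOR embedding field, indexed with IVF_FLAT under the L2 metric.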
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility

def create_milvus_collection(collection_name, dim):
    connections.connect(host='127.0.0.1', port='19530')
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)

    fields = [
        FieldSchema(name='id', dtype=DataType.INT64, description='ids', is_primary=True, auto_id=False),
        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='embedding vectors', dim=dim)
    ]
    schema = CollectionSchema(fields=fields, description='text image search')
    collection = Collection(name=collection_name, schema=schema)

    # Create an IVF_FLAT index for the embedding field.
    index_params = {
        'metric_type': 'L2',
        'index_type': 'IVF_FLAT',
        'params': {'nlist': 512}
    }
    collection.create_index(field_name='embedding', index_params=index_params)
    return collection

collection = create_milvus_collection('text_image_search', 512)
from towhee import ops, pipe, DataCollection
import numpy as np

# This section requires teddy.png in the working directory; otherwise it raises an error.
# Image pipeline: decode the image, embed it with CLIP, then L2-normalize the vector.
p = (
    pipe.input('path')
    .map('path', 'img', ops.image_decode.cv2('rgb'))
    .map('img', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='image'))
    .map('vec', 'vec', lambda x: x / np.linalg.norm(x))
    .output('img', 'vec')
)
DataCollection(p('./teddy.png')).show()
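# Text pipeline: embed a caption with the same CLIP model, so text and image vectors share one space.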
p2 = (
    pipe.input('text')
    .map('text', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='text'))
    .map('vec', 'vec', lambda x: x / np.linalg.norm(x))
    .output('text', 'vec')
)
DataCollection(p2("A teddybear on a skateboard in Times Square.")).show()
time.sleep(60)
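# Recreate the collection, then build the insert pipeline: read (id, path) pairs from the CSV,
# embed each image with CLIP, and insert the normalized vectors into Milvus.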
collection = create_milvus_collection('text_image_search', 512)
def read_csv(csv_path, encoding='utf-8-sig'):
    import csv
    with open(csv_path, 'r', encoding=encoding) as f:
        data = csv.DictReader(f)
        for line in data:
            yield int(line['id']), line['path']
p3 = (
    pipe.input('csv_file')
    .flat_map('csv_file', ('id', 'path'), read_csv)
    .map('path', 'img', ops.image_decode.cv2('rgb'))
    .map('img', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='image'))
    .map('vec', 'vec', lambda x: x / np.linalg.norm(x))
    .map(('id', 'vec'), (), ops.ann_insert.milvus_client(host='127.0.0.1', port='19530', collection_name='text_image_search'))
    .output()
)
ret = p3('reverse_image_search.csv')
time.sleep(120)
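# Load the collection into memory so it can be searched, then report the number of inserted vectors.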
collection.load()
time.sleep(120)
print('Total number of inserted data is {}.'.format(collection.num_entities))
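# Helper used by the display pipeline: look up result ids in the CSV manifest and decode the images.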
import pandas as pd
import cv2

def read_image(image_ids):
    df = pd.read_csv('reverse_image_search.csv')
    id_img = df.set_index('id')['path'].to_dict()
    imgs = []
    decode = ops.image_decode.cv2('rgb')
    for image_id in image_ids:
        path = id_img[image_id]
        imgs.append(decode(path))
    return imgs
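# Search pipeline with image display: embed the text query, retrieve the top-5 nearest
# vectors from Milvus, and decode the matching images for inspection.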
p4 = (
    pipe.input('text')
    .map('text', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='text'))
    .map('vec', 'vec', lambda x: x / np.linalg.norm(x))
    .map('vec', 'result', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', collection_name='text_image_search', limit=5))
    .map('result', 'image_ids', lambda x: [item[0] for item in x])
    .map('image_ids', 'images', read_image)
    .output('text', 'images')
)
DataCollection(p4("A white dog")).show()
DataCollection(p4("A black dog")).show()
search_pipeline = (
    pipe.input('text')
    .map('text', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='text'))
    .map('vec', 'vec', lambda x: x / np.linalg.norm(x))
    .map('vec', 'result', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', collection_name='text_image_search', limit=5))
    .map('result', 'image_ids', lambda x: [item[0] for item in x])
    .output('image_ids')
)

def search(text):
    df = pd.read_csv('reverse_image_search.csv')
    id_img = df.set_index('id')['path'].to_dict()
    image_ids = search_pipeline(text).to_list()[0][0]
    return [id_img[image_id] for image_id in image_ids]
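# Gradio UI: a single text box as input and five image slots for the retrieved results.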
import gradio

# gradio.inputs / gradio.outputs were removed in newer Gradio releases; use the top-level components instead.
interface = gradio.Interface(
    search,
    gradio.Textbox(lines=1),
    [gradio.Image(type="filepath", label=None) for _ in range(5)]
)
interface.launch()