# texttoimg / app.py
import pandas as pd
import time
from zipfile import ZipFile
# Unpack the dataset archive (images plus a CSV manifest).
with ZipFile('reverse_image_search.zip', 'r') as zf:
    print('Extracting all the files now...')
    zf.extractall()
    print('Done!')
df = pd.read_csv('reverse_image_search.csv')
print(df.head())
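# The manifest is expected to map integer ids to image paths; both columns
# are relied on throughout the script, so a minimal sanity check up front
# (assumption: the CSV ships exactly these column names):
assert {'id', 'path'} <= set(df.columns), "CSV must contain 'id' and 'path' columns"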
import cv2
from towhee.types.image import Image
id_img = df.set_index('id')['path'].to_dict()
def read_images(results):
    # Map search hits back to displayable images via the id -> path lookup.
    imgs = []
    for hit in results:
        path = id_img[hit.id]
        imgs.append(Image(cv2.imread(path), 'BGR'))
    return imgs
from milvus import default_server
from pymilvus import connections, utility

# Start the embedded Milvus server, give it time to come up, then connect.
default_server.start()
time.sleep(20)
connections.connect(host='127.0.0.1', port=default_server.listen_port)
print(utility.get_server_version())
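# Sanity check that the server is reachable: a fresh instance should report
# no collections yet.
print(utility.list_collections())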
from pymilvus import FieldSchema, CollectionSchema, DataType, Collection
def create_milvus_collection(collection_name, dim):
    connections.connect(host='127.0.0.1', port='19530')
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)

    fields = [
        FieldSchema(name='id', dtype=DataType.INT64, description='ids', is_primary=True, auto_id=False),
        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='embedding vectors', dim=dim)
    ]
    schema = CollectionSchema(fields=fields, description='text image search')
    collection = Collection(name=collection_name, schema=schema)

    # Create an IVF_FLAT index on the embedding field.
    index_params = {
        'metric_type': 'L2',
        'index_type': 'IVF_FLAT',
        'params': {'nlist': 512}
    }
    collection.create_index(field_name='embedding', index_params=index_params)
    return collection
collection = create_milvus_collection('text_image_search', 512)
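# clip_vit_base_patch16 emits 512-dimensional embeddings, which is why the
# collection is created with dim=512. Quick schema check:
print(collection.schema)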
from towhee import ops, pipe, DataCollection
import numpy as np
# NOTE: this section expects teddy.png in the working directory; image
# decoding will raise an error otherwise.
p = (
    pipe.input('path')
        .map('path', 'img', ops.image_decode.cv2('rgb'))
        .map('img', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='image'))
        .map('vec', 'vec', lambda x: x / np.linalg.norm(x))
        .output('img', 'vec')
)
DataCollection(p('./teddy.png')).show()
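# Because of the normalisation step, every embedding should be unit length.
# A minimal check, assuming towhee's DataQueue.get() returns the output
# columns ('img', 'vec') in order:
_, vec = p('./teddy.png').get()
print('image embedding norm:', np.linalg.norm(vec))  # expect ~1.0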
p2 = (
    pipe.input('text')
        .map('text', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='text'))
        .map('vec', 'vec', lambda x: x / np.linalg.norm(x))
        .output('text', 'vec')
)
DataCollection(p2("A teddybear on a skateboard in Times Square.")).show()
time.sleep(60)
# Re-create the collection so it starts empty before the bulk insert below.
collection = create_milvus_collection('text_image_search', 512)
def read_csv(csv_path, encoding='utf-8-sig'):
    import csv
    with open(csv_path, 'r', encoding=encoding) as f:
        data = csv.DictReader(f)
        for line in data:
            yield int(line['id']), line['path']
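# read_csv yields one (id, path) tuple per row; flat_map below unpacks each
# tuple into the pipeline's 'id' and 'path' columns.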
p3 = (
    pipe.input('csv_file')
        .flat_map('csv_file', ('id', 'path'), read_csv)
        .map('path', 'img', ops.image_decode.cv2('rgb'))
        .map('img', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='image'))
        .map('vec', 'vec', lambda x: x / np.linalg.norm(x))
        .map(('id', 'vec'), (), ops.ann_insert.milvus_client(host='127.0.0.1', port='19530', collection_name='text_image_search'))
        .output()
)
ret = p3('reverse_image_search.csv')
# Give Milvus time to flush the inserted segments before loading the
# collection; num_entities only reflects flushed data.
time.sleep(120)
collection.load()
time.sleep(120)
print('Total number of inserted entities: {}.'.format(collection.num_entities))
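# For reference, the same lookup can be issued directly against pymilvus,
# bypassing towhee; the ann_search op used below wraps an equivalent call.
# (A sketch, reusing the text pipeline p2 from above.)
_, query_vec = p2('A teddybear on a skateboard in Times Square.').get()
hits = collection.search(
    data=[query_vec.tolist()],
    anns_field='embedding',
    param={'metric_type': 'L2', 'params': {'nprobe': 10}},
    limit=5,
)
print('direct search ids:', [hit.id for hit in hits[0]])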
def read_image(image_ids):
    # Resolve Milvus ids back to file paths and decode the images for display.
    df = pd.read_csv('reverse_image_search.csv')
    id_img = df.set_index('id')['path'].to_dict()
    imgs = []
    decode = ops.image_decode.cv2('rgb')
    for image_id in image_ids:
        path = id_img[image_id]
        imgs.append(decode(path))
    return imgs
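# Note: read_image re-reads the CSV on every call; caching the id -> path
# dict once at module level (as id_img is above) would avoid the repeated
# disk read.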
p4 = (
    pipe.input('text')
        .map('text', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='text'))
        .map('vec', 'vec', lambda x: x / np.linalg.norm(x))
        .map('vec', 'result', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', collection_name='text_image_search', limit=5))
        .map('result', 'image_ids', lambda x: [item[0] for item in x])
        .map('image_ids', 'images', read_image)
        .output('text', 'images')
)
DataCollection(p4("A white dog")).show()
DataCollection(p4("A black dog")).show()
search_pipeline = (
    pipe.input('text')
        .map('text', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='text'))
        .map('vec', 'vec', lambda x: x / np.linalg.norm(x))
        .map('vec', 'result', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', collection_name='text_image_search', limit=5))
        .map('result', 'image_ids', lambda x: [item[0] for item in x])
        .output('image_ids')
)
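# search_pipeline stops at the raw ids: the Gradio handler below resolves
# them to file paths itself, since gradio.Image(type='filepath') renders
# images straight from disk.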
def search(text):
    df = pd.read_csv('reverse_image_search.csv')
    id_img = df.set_index('id')['path'].to_dict()
    image_ids = search_pipeline(text).to_list()[0][0]
    return [id_img[image_id] for image_id in image_ids]
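# Example usage: search("A white dog") returns up to five image paths,
# ranked by L2 distance between the query and image embeddings.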
import gradio

# Gradio UI: one query text box in, the five best-matching images out.
# (Uses the Gradio 3+ component API; the original gradio.inputs/gradio.outputs
# modules were removed in newer releases.)
interface = gradio.Interface(
    search,
    gradio.Textbox(lines=1),
    [gradio.Image(type='filepath', label=None) for _ in range(5)]
)
interface.launch()
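# launch() serves the app on Gradio's default local address
# (http://127.0.0.1:7860) when run outside a hosted environment.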