nampham1106's picture
first commit
ab9b7a8
import os
from tqdm.auto import tqdm
from utils.utils import create_client
from pymilvus import Collection, DataType, FieldSchema, CollectionSchema, utility
from utils.get_embeddings import preprocess_image, extract_features, create_resnet18_model
# Name of the Milvus collection that load_collection() looks up and
# create_collection() creates.
COLLECTION_NAME = "Resnet18"
# NOTE(review): presumably the length of each image embedding vector
# (ResNet-18 features) -- confirm against utils.get_embeddings.extract_features.
EMBEDDING_DIM = 512
# Folder handed to insert_data() when the collection has to be built from scratch.
# NOTE(review): absolute, machine-specific path -- adjust per deployment.
IMAGE_FOLDER = "/home/nampham/Desktop/image-retrieval/data/images_mr"
# Runs at import time. NOTE(review): presumably opens the Milvus connection that
# the pymilvus calls in this module rely on -- confirm in utils.utils.create_client.
client = create_client()
def load_collection():
    """Return the image collection, building and populating it when absent.

    If a collection named ``COLLECTION_NAME`` already exists it is loaded and
    its load state is printed. Otherwise a new collection is created, every
    image under ``IMAGE_FOLDER`` is embedded and inserted, and an index is
    built on the result.

    Returns:
        The collection object, either the existing one or the newly built one.
    """
    if not utility.has_collection(COLLECTION_NAME):
        print("Please create a collection and insert data!")
        fresh = create_collection()
        # Populate the new collection first, then index it for search.
        insert_data(create_resnet18_model(), fresh, IMAGE_FOLDER)
        create_index(fresh)
        return fresh

    print("Load and use collection right now!")
    existing = Collection(COLLECTION_NAME)
    existing.load()
    print(utility.load_state(COLLECTION_NAME))
    return existing
def create_collection():
    """Create the Milvus collection that stores one embedding per image.

    Schema:
        image_id: INT64 primary key, supplied by the caller on insert.
        image_embedding: FLOAT_VECTOR holding ``EMBEDDING_DIM`` floats.

    Returns:
        The ``Collection`` bound to ``COLLECTION_NAME`` with the schema above.
    """
    image_id = FieldSchema(
        name="image_id",
        dtype=DataType.INT64,
        is_primary=True,
        description="Image ID",
    )
    image_embedding = FieldSchema(
        name="image_embedding",
        dtype=DataType.FLOAT_VECTOR,
        # Vector fields must declare their dimensionality.
        dim=EMBEDDING_DIM,
        description="Image Embedding",
    )
    schema = CollectionSchema(
        fields=[image_id, image_embedding],
        # insert_data() provides explicit IDs taken from the image file names,
        # so the primary key must not be auto-generated.
        auto_id=False,
        description="Image Retrieval using Resnet18",
    )
    return Collection(name=COLLECTION_NAME, schema=schema)
def insert_data(model, collection, image_folder):
    """Embed every ``<integer>.jpg`` image in a folder and insert it.

    Files whose name is not a decimal integer followed by ``.jpg`` (the
    extension is matched case-insensitively) are skipped. The integer becomes
    the row's ID. If two files map to the same integer (e.g. ``1.jpg`` and
    ``01.jpg``), only one of them is kept.

    Args:
        model: Feature extractor passed through to ``extract_features``.
        collection: Object exposing ``insert(entities)`` and ``flush()``.
        image_folder: Directory containing the image files.

    Returns:
        None. When the folder holds no matching image, nothing is inserted.
    """
    # Keep the real file name for each ID, so names such as "007.jpg" or
    # "3.JPG" still resolve instead of being rebuilt as "<id>.jpg".
    files_by_id = {}
    for file_name in os.listdir(image_folder):
        stem, extension = os.path.splitext(file_name)
        if extension.lower() != ".jpg" or not stem.isdecimal():
            continue
        files_by_id[int(stem)] = file_name

    if not files_by_id:
        print(f"No '<id>.jpg' images found in {image_folder}; nothing inserted.")
        return

    image_ids = sorted(files_by_id)
    image_embeddings = []
    for image_id in tqdm(image_ids):
        image_path = os.path.join(image_folder, files_by_id[image_id])
        processed_image = preprocess_image(image_path)
        image_embeddings.append(extract_features(model, processed_image))

    # Column-based insert: one list of IDs and one parallel list of embeddings.
    collection.insert([image_ids, image_embeddings])
    collection.flush()
def create_index(collection):
    """Build an IVF_FLAT / L2 index on the embedding field and load the collection.

    Args:
        collection: Object exposing ``create_index(field_name=..., index_params=...)``
            and ``load()``.

    Returns:
        None.
    """
    index_params = {
        "index_type": "IVF_FLAT",
        "metric_type": "L2",
        # IVF_FLAT clusters vectors into ``nlist`` buckets; state it explicitly
        # (128 is Milvus' documented default) rather than passing empty params.
        "params": {"nlist": 128},
    }
    collection.create_index(
        # Name of the vector field declared in create_collection(); the
        # FieldSchema object itself is local to that function.
        field_name="image_embedding",
        index_params=index_params,
    )
    # Load the collection after indexing.
    collection.load()