|
import os |
|
import numpy as np |
|
from matplotlib import rcParams |
|
import matplotlib.pyplot as plt |
|
from tensorflow.keras.models import load_model, Model |
|
from tensorflow.keras.utils import load_img, save_img, img_to_array |
|
from tensorflow.keras.applications.vgg19 import preprocess_input |
|
from tensorflow.keras.layers import GlobalAveragePooling2D |
|
from pymilvus import connections, Collection, utility |
|
from requests import get |
|
import streamlit as st |
|
import zipfile |
|
|
|
|
|
|
|
def unzip_images(archive_path="Vegetable Images.zip", extract_to='.'):
    '''
    Extract every member of a zip archive into a directory.

    Called without arguments it behaves as before: "Vegetable Images.zip"
    is extracted into the current working directory.

    Args:
        archive_path: Path of the zip archive to read.
        extract_to: Directory the archive members are extracted into.

    Raises:
        FileNotFoundError: If ``archive_path`` does not exist.
        zipfile.BadZipFile: If ``archive_path`` is not a valid zip archive.
    '''
    with zipfile.ZipFile(archive_path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)
    print('unzipped images')
|
|
|
# Extract the bundled archive only when the image folder is not on disk yet.
_images_present = os.path.exists('Vegetable Images/')
if not _images_present:
    unzip_images()
|
|
|
|
|
class ImageVectorizer:
    '''
    Produce a vector representation of an image.

    The vector comes from the VGG19 model fine tuned on vegetable images for
    classification, cut at its 'block5_pool' layer and followed by global
    average pooling.
    '''

    def __init__(self):
        self.__model = self.get_model()

    @staticmethod
    @st.cache_resource
    def get_model():
        '''
        Load the saved classifier and return the truncated embedding model.

        The function is wrapped with ``st.cache_resource``; it prints
        'loaded model' when the body actually runs.
        '''
        classifier = load_model('vegetable_classification_model_vgg.h5')
        # Re-use the classifier's input, but stop at block5_pool and pool the
        # feature maps instead of running the classification head.
        pooled_features = GlobalAveragePooling2D()(
            classifier.get_layer('block5_pool').output
        )
        embedding_model = Model(inputs=classifier.input, outputs=pooled_features)
        print('loaded model')
        return embedding_model

    def vectorize(self, img_path: str):
        '''
        Return the embedding of the image stored at ``img_path``.

        The image is loaded as RGB, resized to 224x224 and passed through
        the VGG19 ``preprocess_input`` before being fed to the model.

        Returns:
            The first (and only) row of the model output, as a NumPy array.
        '''
        pil_image = load_img(img_path, color_mode="rgb", target_size=(224, 224))
        pixels = preprocess_input(img_to_array(pil_image))
        # The model is called on a batch, so wrap the single image in one.
        single_image_batch = np.array([pixels])
        return self.__model(single_image_batch).numpy()[0]
|
|
|
|
|
@st.cache_resource
def get_milvus_collection():
    '''
    Connect to Milvus and return the loaded collection.

    Connection settings are read from the environment: ``URI`` and ``TOKEN``
    for the connection, ``COLLECTION_NAME`` for the collection to open.
    The function is wrapped with ``st.cache_resource``.
    '''
    connections.connect(
        "default",
        uri=os.environ.get("URI"),
        token=os.environ.get("TOKEN"),
    )
    print("Connected to DB")
    milvus_collection = Collection(name=os.environ.get("COLLECTION_NAME"))
    # load() is called before the collection is handed out for searching.
    milvus_collection.load()
    return milvus_collection
|
|
|
|
|
def plot_images(input_image_path: str, similar_img_paths: list):
    '''
    Show the input image and a 5x3 grid of similar images in the app.

    Output is rendered inside the module-level ``placeholder`` Streamlit
    element, which must exist before this function is called.

    Args:
        input_image_path: Path of the query image.
        similar_img_paths: Paths of the similar images, in display order.
            Only the first 15 are drawn; when fewer are given the remaining
            grid cells are left blank instead of raising ``IndexError``.
    '''
    rows = 5
    cols = 3
    fig, ax = plt.subplots(rows, cols, figsize=(12, 20))
    try:
        # Hide the axes of every cell up front so unused cells stay empty.
        for cell in ax.flat:
            cell.axis("off")

        # ax.flat walks the grid row by row; zip() stops at whichever of the
        # grid or the path list runs out first.
        for cell, sim_path in zip(ax.flat, similar_img_paths):
            sim_image = load_img(sim_path, color_mode="rgb", target_size=(224, 224))
            cell.imshow(sim_image)
        plt.subplots_adjust(wspace=0.01, hspace=0.01)

        # NOTE(review): this global setting is changed after `fig` was
        # created, so it presumably only affects figures created later in
        # the same process - confirm that this is intended.
        rcParams.update({'figure.autolayout': True})
        input_image = load_img(input_image_path, color_mode="rgb", target_size=(224, 224))
        with placeholder.container():
            st.markdown('<p style="font-size: 20px; font-weight: bold">Input image</p>', unsafe_allow_html=True)
            st.image(input_image)

            st.write(' \n')

            st.markdown('<p style="font-size: 20px; font-weight: bold">Similar images</p>', unsafe_allow_html=True)
            st.pyplot(fig)
    finally:
        # pyplot keeps a reference to every figure it creates; close this one
        # so figures do not pile up across script reruns.
        plt.close(fig)
|
|
|
|
|
def find_similar_images(img_path: str, top_n: int=15):
    '''
    Search the collection for images similar to ``img_path`` and plot them.

    The image is vectorized with the module-level ``vectorizer``, searched
    against the ``image_vector`` field of the module-level ``collection``
    using the L2 metric, and the ``image_path`` of every hit is passed to
    ``plot_images`` together with the input path.

    Args:
        img_path: Path of the query image.
        top_n: Value passed as the search ``limit``.
    '''
    query_vector = vectorizer.vectorize(img_path)
    search_results = collection.search(
        [query_vector],
        anns_field='image_vector',
        param={"metric_type": "L2"},
        limit=top_n,
        guarantee_timestamp=1,
        output_fields=['image_path'],
    )

    # The result is iterated as groups of hits; flatten them into one list.
    similar_paths = []
    for hits in search_results:
        for hit in hits:
            similar_paths.append(hit.entity.get('image_path'))

    plot_images(img_path, similar_paths)
|
|
|
|
|
def delete_file(path_: str) -> None:
    '''
    Remove the file at ``path_``; do nothing when it does not exist.

    The removal is attempted directly rather than checking for existence
    first, so a file that disappears between the check and the removal
    can no longer cause an error.

    Args:
        path_: Path of the file to remove.

    Raises:
        OSError: For failures other than the file being absent (for example
            when ``path_`` is a directory or permissions are missing).
    '''
    try:
        os.remove(path_)
    except FileNotFoundError:
        # Already gone - nothing to clean up.
        pass
|
|
|
@st.cache_resource
def get_upload_path():
    '''
    Return the path downloaded input images are written to.

    Creates the './uploads' directory when it does not exist yet and returns
    the path of 'input.jpg' inside it. The function is wrapped with
    ``st.cache_resource``.
    '''
    uploads_dir = os.path.join('.', 'uploads')
    if not os.path.exists(uploads_dir):
        os.makedirs(uploads_dir)
    return os.path.join(uploads_dir, "input.jpg")
|
|
|
def process_input_image(img_url, timeout=10):
    '''
    Download the image at ``img_url`` and save it to the upload path.

    Args:
        img_url: URL of the image to download.
        timeout: Seconds to wait for the server before giving up, passed to
            the HTTP request. Without it a stalled server would block the
            app indefinitely.

    Returns:
        The path the downloaded bytes were written to (the value of
        ``get_upload_path()``).

    Raises:
        requests.exceptions.RequestException: If the request fails, times
            out, or the server answers with an error status.
    '''
    upload_file_path = get_upload_path()
    # A browser-like User-Agent is sent with the request.
    headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/102.0.0.0 Safari/537.36'}
    # NOTE(review): img_url is fetched as given; no scheme or host
    # restrictions are applied here - confirm whether that is acceptable.
    response = get(img_url, headers=headers, timeout=timeout)
    # Stop on 4xx/5xx responses instead of saving an error page as the image.
    response.raise_for_status()
    with open(upload_file_path, "wb") as file:
        file.write(response.content)
    return upload_file_path
|
|
|
|
|
# Module-level objects used by find_similar_images(). Both underlying loaders
# (ImageVectorizer.get_model and get_milvus_collection) are wrapped with
# st.cache_resource.
vectorizer = ImageVectorizer()
collection = get_milvus_collection()

# Page body: description, URL input, and - when a URL was entered - the
# download / search / display pipeline. Any failure is reported via st.error.
try:
    st.markdown("<h3>Find Similar Vegetable Images</h3>", unsafe_allow_html=True)
    desc = '''<p style="font-size: 15px;">Allowed Vegetables: Broad Bean, Bitter Gourd, Bottle Gourd,
    Green Round Brinjal, Broccoli, Cabbage, Capsicum, Carrot, Cauliflower, Cucumber,
    Raw Papaya, Potato, Green Pumpkin, Radish, Tomato.
    </p>
    <p style="font-size: 13px;">Image embeddings are extracted from a fine-tuned VGG model. The model is fine-tuned on <a href="https://www.kaggle.com/datasets/misrakahmed/vegetable-image-dataset" target="_blank">images</a> clicked using a mobile phone camera.
    Embeddings of 20,000 vegetable images are stored in Milvus vector database. Embeddings of the input image are computed and 15 most similar images (based on L2 distance) are displayed.</p>
    '''
    st.markdown(desc, unsafe_allow_html=True)
    img_url = st.text_input("Paste the image URL of a vegetable and hit Enter:", "")
    # plot_images() renders into this module-level placeholder.
    placeholder = st.empty()
    if img_url:
        placeholder.empty()
        img_path = process_input_image(img_url)
        try:
            find_similar_images(img_path, 15)
        finally:
            # Remove the downloaded file even when the search or the
            # rendering fails, so it is not left behind on disk.
            delete_file(img_path)
except Exception as e:
    st.error(f'An unexpected error occured: \n{e}')
|
|