# Soumyapro's picture
# uploaded_app and requirements file
# fadbd1c
import streamlit as st
import torch
import cv2
import albumentations as A
import torch.nn.functional as F
import pandas as pd
import numpy as np
import pickle
from pathlib import Path
from model.clip_model import CLIPModel
# Page heading rendered at the top of the Streamlit app.
st.title("Product Image to description Prediction in E-commerce")
def get_emebddings(file_path):
    """Load and return the object pickled at *file_path*.

    The file is opened in binary mode and passed to ``pickle.load``; the
    deserialized object is returned unchanged.
    """
    with open(file_path, 'rb') as handle:
        return pickle.load(handle)
# def find_text_matches(model,text_emebddings,)
# Locations of the pickled caption embeddings, the caption table and the
# trained model weights, relative to the working directory.
embeddings_data_path = Path("./data/embeddings.pkl")
image_caption_path = Path("./data/image_details.csv")
model_path = Path('./model/best.pt')

# Build the model on CPU and restore the trained weights.
clip_model = CLIPModel().to('cpu')
clip_model.load_state_dict(torch.load(model_path, map_location='cpu'))
# Modules start in training mode; switch to evaluation mode so that any
# dropout / batch-norm layers inside the model behave deterministically
# during prediction.
clip_model.eval()

# Caption embeddings (unpickled as stored) and the caption table read from CSV.
embeddings = get_emebddings(embeddings_data_path)
caption_df = pd.read_csv(image_caption_path)
# print(caption_df.head())
def find_text_matches(model, text_emebddings, image_path, actual_captions, max_out=4):
    """Return the captions whose embeddings best match the image at *image_path*.

    The image is read with OpenCV, converted to RGB, resized to 224x224 and
    normalized, then passed through ``model.image_encoder`` and
    ``model.image_projection``. The resulting embedding and the caption
    embeddings are L2-normalized and compared by dot product.

    Args:
        model: Object exposing ``image_encoder`` and ``image_projection``
            callables.
        text_emebddings: Tensor of caption embeddings; assumed to be
            ``(num_captions, dim)`` with rows aligned to ``actual_captions``
            -- TODO confirm against the code that produced the pickle.
        image_path: Path of the image file to describe.
        actual_captions: Indexable collection of captions, looked up by the
            integer positions of the best-scoring embeddings.
        max_out: Maximum number of captions to return.

    Returns:
        A list of at most ``max_out`` entries of ``actual_captions``, best
        match first.

    Raises:
        FileNotFoundError: If the image cannot be read from ``image_path``.
    """
    # cv2.imread reports a missing/unreadable file by returning None.
    image = cv2.imread(str(image_path))
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    transform = A.Compose([
        A.Resize(224, 224, always_apply=True),
        A.Normalize(max_pixel_value=255.0, always_apply=True),
    ])
    trans_image = transform(image=image)['image']
    # Channels-last -> channels-first, plus a leading batch dimension of 1.
    image_tensor = torch.tensor(trans_image).permute(2, 0, 1).float().unsqueeze(0)

    # Prediction
    with torch.no_grad():
        image_features = model.image_encoder(image_tensor.to('cpu'))
        image_embeddings = model.image_projection(image_features)

        image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
        text_embeddings_n = F.normalize(text_emebddings, p=2, dim=-1)
        dot_similarity = text_embeddings_n @ image_embeddings_n.T

    # Flatten to one score per caption so the stride below walks over the
    # ranked captions rather than over the size-1 batch axis.
    scores = dot_similarity.T.cpu().reshape(-1)

    # Keep every 5th of the top-ranked candidates (the original stride;
    # presumably meant to skip near-duplicate captions -- TODO confirm).
    stride = 5
    # Never ask topk for more candidates than there are captions.
    k = min(max(max_out, 0) * stride, scores.numel())
    _, indices = torch.topk(scores, k=k)

    return [actual_captions[idx] for idx in indices[::stride].tolist()]
# Let the user choose one of the bundled sample images.
st.subheader("Select the Image from Given files path")
images = (
    "./images/0108775015.jpg",
    "./images/0120129014.jpg",
    "./images/0187949019.jpg",
    "./images/0203595036.jpg",
    "./images/0212629031.jpg",
    "./images/0212629048.jpg",
    "./images/0237347052.jpg",
)
image = st.selectbox("images", images)

# Show the current selection.
st.subheader("Selected Image")
st.image(image)

# Run the prediction only when the "Predict" button reports a click, then
# write out each entry returned by find_text_matches.
if st.button("Predict"):
    st.write("Predicted Product Description")
    matches = find_text_matches(
        clip_model,
        embeddings,
        image,
        caption_df['caption'].values,
    )
    for entry in matches:
        st.write(entry)