import streamlit as st
import torch
import cv2
import albumentations as A
import torch.nn.functional as F
import pandas as pd
import numpy as np
import pickle
from pathlib import Path
from model.clip_model import CLIPModel
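
# Streamlit demo: given a product image, retrieve the closest product
# descriptions with a trained CLIP model. The image is projected into the
# shared image-text embedding space and ranked against the precomputed
# caption embeddings by cosine similarity.
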
st.title("Product Image to Description Prediction in E-commerce")

def get_embeddings(file_path):
    """Load precomputed text embeddings from a pickle file.

    The file is expected to hold a torch.Tensor of projected caption
    embeddings, one row per caption in image_details.csv.
    """
    with open(file_path, 'rb') as file:
        data = pickle.load(file)
    return data


# Paths to the precomputed assets
embeddings_data_path = Path("./data/embeddings.pkl")
image_caption_path = Path("./data/image_details.csv")
model_path = Path('./model/best.pt')

# Load the trained CLIP model, the text embeddings, and the caption table
clip_model = CLIPModel().to('cpu')
clip_model.load_state_dict(torch.load(model_path, map_location='cpu'))
clip_model.eval()  # inference only: disable dropout / batch-norm updates
embeddings = get_embeddings(embeddings_data_path)
caption_df = pd.read_csv(image_caption_path)



def find_text_matches(model, text_embeddings, image_path, actual_captions, max_out=4):
    """Embed the query image and return the `max_out` captions whose
    precomputed embeddings are most similar to it."""
    item = {}
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Resize and Normalize both default to p=1.0, so always_apply is not needed
    transform = A.Compose([
        A.Resize(224, 224),
        A.Normalize(max_pixel_value=255.0),
    ])
    trans_image = transform(image=image)['image']
    item['image'] = torch.tensor(trans_image).permute(2, 0, 1).float().unsqueeze(0)

    # Prediction: project the image into the shared image-text space and
    # rank all captions by cosine similarity (dot product of unit vectors)
    with torch.no_grad():
        image_features = model.image_encoder(item['image'].to('cpu'))
        image_embeddings = model.image_projection(image_features)
        image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
        text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
        dot_similarity = text_embeddings_n @ image_embeddings_n.T
        # take the indices of the max_out most similar captions
        values, indices = torch.topk(dot_similarity.T.cpu(), k=max_out)
        matches = [actual_captions[idx] for idx in indices.squeeze(0).tolist()]
    return matches



st.subheader("Select the Image from Given files path")
images = ("./images/0108775015.jpg","./images/0120129014.jpg","./images/0187949019.jpg","./images/0203595036.jpg","./images/0212629031.jpg","./images/0212629048.jpg","./images/0237347052.jpg")
image = st.selectbox("images",images)
st.subheader("Selected Image")
st.image(image)
ok = st.button("Predict")
if ok:
    # st.write("true")
    st.write("Predicted Product Description")
    matches = find_text_matches(clip_model,embeddings,image,caption_df['caption'].values)
    for i in matches:
        st.write(i)