roy214 commited on
Commit
d6ba9d2
·
verified ·
1 Parent(s): 8431f5b

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +115 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,117 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
1
  import streamlit as st
2
+ import boto3
3
+ from botocore.exceptions import NoCredentialsError
4
+ from io import BytesIO
5
+ from PIL import Image
6
+ import pandas as pd
7
+ import matplotlib.pyplot as plt
8
+ from PIL import Image
9
+ import os
10
+ import faiss
11
+ import pickle
12
+ import torch
13
+ from transformers import CLIPModel, CLIPProcessor
14
+ from huggingface_hub import hf_hub_download
15
+ import json
16
+
17
+ # Lấy thông tin từ Streamlit Secrets
18
+ # aws_access_key = st.secrets["AWS_ACCESS_KEY_ID"]
19
+ # aws_secret_key = st.secrets["AWS_SECRET_ACCESS_KEY"]
20
+ # aws_access_key = 'AKIATS5GX2D62YHYRFWL'
21
+ # aws_secret_key = '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'
22
+
23
+
24
+ # Khởi tạo client S3 với thông tin cấu hình từ secrets
25
+ s3 = boto3.client(
26
+ 's3'
27
+ )
28
+
29
+ def get_image_from_s3(bucket_name, img_id):
30
+ # object_name = str(img_id) + '.jpg'
31
+
32
+ try:
33
+ # # Tải ảnh từ S3
34
+ # response = s3.get_object(Bucket=bucket_name, Key=object_name)
35
+ # img_data = response['Body'].read() # Đọc dữ liệu ảnh
36
+
37
+ # # Chuyển đổi dữ liệu ảnh thành ảnh PIL
38
+ # img = Image.open(BytesIO(img_data))
39
+ # return img
40
+
41
+ return f"https://{bucket_name}.s3.amazonaws.com/{img_id}.jpg"
42
+
43
+ except NoCredentialsError:
44
+ st.error("Credentials not available.")
45
+ return None
46
+ except Exception as e:
47
+ st.error(f"Error fetching image: {e}")
48
+ return None
49
+
50
+ def show_img(img_id, score):
51
+ # Lấy ảnh từ S3
52
+ img = get_image_from_s3(bucket_name, img_id)
53
+
54
+ if img:
55
+ img_style = style[style['id'] == img_id]
56
+
57
+ if not img_style.empty:
58
+ parts = []
59
+ parts.append(str(img_style['gender'].values[0]))
60
+ parts.append(str(img_style['masterCategory'].values[0]))
61
+ parts.append(str(img_style['subCategory'].values[0]))
62
+ parts.append(str(img_style['articleType'].values[0]))
63
+ parts.append(str(img_style['baseColour'].values[0]))
64
+ parts.append(str(img_style['year'].values[0]))
65
+ parts.append(str(img_style['usage'].values[0]))
66
+ parts.append(str(img_style['productDisplayName'].values[0]))
67
+
68
+ text = '- '.join(parts)
69
+ text += f'\n Score: {score}'
70
+ st.image(img, caption=text, use_container_width=True)
71
+
72
+
73
+
74
+ def search_faiss(model, processor, index, id_map, prompt, top_k=5, device='cpu'):
75
+ inputs = processor(text=[prompt], return_tensors='pt', padding=True).to(device)
76
+ with torch.no_grad():
77
+ txt_emb = model.get_text_features(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
78
+ txt_emb = txt_emb / txt_emb.norm(p=2, dim=-1, keepdim=True)
79
+ q = txt_emb.cpu().numpy().astype('float32')
80
+
81
+ D, I = index.search(q, top_k)
82
+ return [(id_map[i], float(D[0][j])) for j, i in enumerate(I[0])]
83
+
84
+
85
+ def running(prompt, top_k=5):
86
+ results = search_faiss(
87
+ model, processor,
88
+ index, id_map,
89
+ prompt=prompt,
90
+ top_k=top_k,
91
+ )
92
+
93
+ for img_id, score in results:
94
+ show_img(img_id, score)
95
+
96
+
97
+ style = pd.read_csv('styles.csv', usecols=range(10))
98
+ bucket_name = "image-text-retrieval" # Tên bucket của bạn
99
+ your_username = 'roy214'
100
+
101
+
102
+ # Load model từ Hugging Face Hub
103
+ model = CLIPModel.from_pretrained(f"{your_username}/clip-finetuned-fashion").to("cpu").eval()
104
+ processor = CLIPProcessor.from_pretrained(f"{your_username}/clip-finetuned-fashion")
105
+
106
+ # Load FAISS index và id_map
107
+ index_path = hf_hub_download(repo_id=f"{your_username}/clip-finetuned-fashion", filename="faiss_index.bin")
108
+ mapping_path = hf_hub_download(repo_id=f"{your_username}/clip-finetuned-fashion", filename="id_map.json")
109
+
110
+ index = faiss.read_index(index_path)
111
+
112
+ with open(mapping_path, "r") as f:
113
+ id_map = json.load(f)
114
 
115
+ # show_img(59403, 19)
116
+ st.text("Enter prompt")
117
+ running("Dress Women Apparel Red", top_k=5)