Neilblaze committed
Commit 9e3ffc4
1 Parent(s): 066ee53

Create app.py

Files changed (1)
  1. app.py +212 -0
app.py ADDED
@@ -0,0 +1,212 @@
+ """Provide a text query describing what you are looking for and get back images with links!"""
+
+ import argparse
+ import logging
+ import os
+ import pickle
+ import zipfile
+ from pathlib import Path
+ from typing import Any, Callable, Dict, List, Tuple
+
+ import gradio as gr
+ import torch
+ import wandb
+ from multilingual_clip import pt_multilingual_clip
+ from PIL.Image import Image
+ from sentence_transformers import SentenceTransformer, util
+ from transformers import AutoTokenizer
+
+ print(__file__)
+
+ os.environ["CUDA_VISIBLE_DEVICES"] = ""  # do not use GPU
+
+ logging.basicConfig(level=logging.INFO)
+ DEFAULT_APPLICATION_NAME = "FashGen"
+
+ APP_DIR = Path(__file__).resolve().parent  # the directory for this application
+ README = APP_DIR / "README.md"  # path to an app readme file in HTML/markdown
+
+ DEFAULT_PORT = 11700
+
+ EMBEDDINGS_DIR = "artifacts/img-embeddings"
+ EMBEDDINGS_FILE = os.path.join(EMBEDDINGS_DIR, "embeddings.pkl")
+ RAW_PHOTOS_DIR = "artifacts/raw-photos"
+
+ # Download image embeddings and raw photos
+ wandb.login(key=os.getenv("WANDB_API_KEY"))  # read the key from the environment; never hard-code credentials
+ api = wandb.Api()
+ artifact_embeddings = api.artifact("ryparmar/fashion-aggregator/unimoda-images:v1")
+ artifact_embeddings.download(EMBEDDINGS_DIR)
+ artifact_raw_photos = api.artifact("ryparmar/fashion-aggregator/unimoda-raw-images:v1")
+ artifact_raw_photos.download("artifacts")
+
+ with zipfile.ZipFile("artifacts/unimoda.zip", "r") as zip_ref:
+     zip_ref.extractall(RAW_PHOTOS_DIR)
+
+
+ class TextEncoder:
+     """Encodes the given text."""
+
+     def __init__(self, model_path="M-CLIP/XLM-Roberta-Large-Vit-B-32"):
+         self.model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(model_path)
+         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+     @torch.no_grad()
+     def encode(self, query: str) -> torch.Tensor:
+         """Predict/infer the text embedding for a given query."""
+         query_emb = self.model.forward([query], self.tokenizer)
+         return query_emb
+
+
+ class ImageEncoder:
+     """Encodes the given image."""
+
+     def __init__(self, model_path="clip-ViT-B-32"):
+         self.model = SentenceTransformer(model_path)
+
+     @torch.no_grad()
+     def encode(self, image: Image) -> torch.Tensor:
+         """Predict/infer the image embedding for a given image."""
+         image_emb = self.model.encode([image], convert_to_tensor=True, show_progress_bar=False)
+         return image_emb
+
+
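+ # Note (an assumption worth verifying): M-CLIP/XLM-Roberta-Large-Vit-B-32 is trained to
+ # map multilingual text into the same embedding space as the clip-ViT-B-32 image encoder,
+ # which is what makes the cosine-similarity search between text and image vectors below meaningful.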
+ class Retriever:
+     """Retrieves relevant images for a given text query."""
+
+     def __init__(self, image_embeddings_path=None):
+         self.text_encoder = TextEncoder()
+         self.image_encoder = ImageEncoder()
+
+         with open(image_embeddings_path, "rb") as file:
+             self.image_names, self.image_embeddings = pickle.load(file)
+         self.image_names = [
+             img_name.replace("fashion-aggregator/fashion_aggregator/data/photos/", "")
+             for img_name in self.image_names
+         ]
+         print("Images:", len(self.image_names))
+
+     @torch.no_grad()
+     def predict(self, text_query: str, k: int = 10) -> List[Any]:
+         """Return the top-k most relevant items for a given text query."""
+         query_emb = self.text_encoder.encode(text_query)
+         relevant_images = util.semantic_search(query_emb, self.image_embeddings, top_k=k)[0]
+         return relevant_images
+
+     @torch.no_grad()
+     def search_images(self, text_query: str, k: int = 6) -> Dict[str, List[Any]]:
+         """Return paths and scores of the top-k most relevant images for a given text query."""
+         images = self.predict(text_query, k)
+         paths_and_scores = {"path": [], "score": []}
+         for img in images:
+             paths_and_scores["path"].append(os.path.join(RAW_PHOTOS_DIR, self.image_names[img["corpus_id"]]))
+             paths_and_scores["score"].append(img["score"])
+         return paths_and_scores
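+
+     # Usage sketch (hypothetical values, assuming the artifacts downloaded above are in place):
+     #   retriever = Retriever(image_embeddings_path=EMBEDDINGS_FILE)
+     #   retriever.search_images("red summer dress", k=6)
+     #   -> {"path": ["artifacts/raw-photos/<img>.jpg", ...], "score": [0.31, ...]}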
+
+
+ def main(args):
+     predictor = PredictorBackend(url=args.model_url)
+     frontend = make_frontend(predictor.run, flagging=args.flagging, gantry=args.gantry, app_name=args.application)
+     frontend.launch(
+         # server_name="0.0.0.0",  # make server accessible, binding all interfaces  # noqa: S104
+         # server_port=args.port,  # set a port to bind to, failing if unavailable
+         # share=False,  # should we create a (temporary) public link on https://gradio.app?
+     )
+
+
+ def make_frontend(
+     fn: Callable[[str], List[str]], flagging: bool = False, gantry: bool = False, app_name: str = "fashion-aggregator"
+ ):
+     """Creates a gradio.Interface frontend for the text-to-image search function."""
+
+     allow_flagging = "manual" if flagging else "never"
+
+     # build a basic browser interface to a Python function
+     frontend = gr.Interface(
+         fn=fn,  # which Python function are we interacting with?
+         outputs=gr.Gallery(label="Relevant Items"),
+         # what input widgets does it need? we configure a text box
+         inputs=gr.components.Textbox(label="Item Description"),
+         title="FashGen",  # what should we display at the top of the page?
+         description=__doc__,  # what should we display just above the interface?
+         cache_examples=False,  # should we cache those inputs for faster inference? slows down start
+         allow_flagging=allow_flagging,  # should we show users the option to "flag" outputs?
+         flagging_options=["incorrect", "offensive", "other"],  # what options do users have for feedback?
+     )
+     return frontend
+
+
+ class PredictorBackend:
+     """Interface to a backend that serves predictions.
+
+     To communicate with a backend accessible via a URL, provide the url kwarg.
+     Otherwise, runs a predictor locally.
+     """
+
+     def __init__(self, url=None):
+         if url is not None:
+             self.url = url
+             self._predict = self._predict_from_endpoint
+             self._search_images = self._predict_from_endpoint  # the endpoint returns paths and scores directly
+         else:
+             model = Retriever(image_embeddings_path=EMBEDDINGS_FILE)
+             self._predict = model.predict
+             self._search_images = model.search_images
+
+     def run(self, text: str):
+         pred, metrics = self._predict_with_metrics(text)
+         self._log_inference(pred, metrics)
+         return pred
+
+     def _predict_with_metrics(self, text: str) -> Tuple[List[str], Dict[str, float]]:
+         paths_and_scores = self._search_images(text)
+         metrics = {"mean_score": sum(paths_and_scores["score"]) / len(paths_and_scores["score"])}
+         return paths_and_scores["path"], metrics
+
+     def _log_inference(self, pred, metrics):
+         for key, value in metrics.items():
+             logging.info(f"METRIC {key} {value}")
+         logging.info(f"PRED >begin\n{pred}\nPRED >end")
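+
+     def _predict_from_endpoint(self, text_query: str, k: int = 6) -> Dict[str, List[Any]]:
+         """Fetch predictions from a model-serving endpoint.
+
+         A minimal sketch, not part of the original commit: it assumes the endpoint
+         accepts JSON with the keys 'text' and 'k' and replies with a JSON body of
+         the form {"path": [...], "score": [...]}, mirroring Retriever.search_images.
+         Adjust to the actual serving contract before use.
+         """
+         import requests  # imported locally so this sketch stays self-contained
+
+         response = requests.post(self.url, json={"text": text_query, "k": k}, timeout=30)
+         response.raise_for_status()  # surface HTTP errors instead of parsing a bad body
+         return response.json()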
+
+
+ def _make_parser():
+     parser = argparse.ArgumentParser(description=__doc__)
+     parser.add_argument(
+         "--model_url",
+         default=None,
+         type=str,
+         help="Identifies a URL to which to send text queries for prediction. Default is None, which instead sends queries to a model running locally.",
+     )
+     parser.add_argument(
+         "--port",
+         default=DEFAULT_PORT,
+         type=int,
+         help=f"Port on which to expose this server. Default is {DEFAULT_PORT}.",
+     )
+     parser.add_argument(
+         "--flagging",
+         action="store_true",
+         help="Pass this flag to allow users to 'flag' model behavior and provide feedback.",
+     )
+     parser.add_argument(
+         "--gantry",
+         action="store_true",
+         help="Pass --flagging and this flag to log user feedback to Gantry. Requires GANTRY_API_KEY to be defined as an environment variable.",
+     )
+     parser.add_argument(
+         "--application",
+         default=DEFAULT_APPLICATION_NAME,
+         type=str,
+         help=f"Name of the Gantry application to which feedback should be logged, if --gantry and --flagging are passed. Default is {DEFAULT_APPLICATION_NAME}.",
+     )
+     return parser
+
+
+ if __name__ == "__main__":
+     parser = _make_parser()
+     args = parser.parse_args()
+     main(args)
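+
+ # Example invocations (hypothetical; flags are defined in _make_parser above):
+ #   python app.py                    # serve the local Retriever
+ #   python app.py --model_url <URL>  # proxy queries to a remote model endpoint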