ryparmar commited on
Commit
10a6818
1 Parent(s): 355ebb4

add fashion-aggregator test app

Browse files
app.py CHANGED
@@ -1,7 +1,148 @@
 
 
 
 
 
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Provide a text query describing what you are looking for and get back out images with links!"""
2
+ import argparse
3
+ import logging
4
+ import os
5
+ import wandb
6
  import gradio as gr
7
 
8
+ from pathlib import Path
9
+ from typing import Callable, Dict, List, Tuple
10
+ from PIL.Image import Image
11
 
12
+ print(__file__)
13
+ import fashion_aggregator.fashion_aggregator as fa
14
+
15
+
16
+ os.environ["CUDA_VISIBLE_DEVICES"] = "" # do not use GPU
17
+
18
+ logging.basicConfig(level=logging.INFO)
19
+ DEFAULT_APPLICATION_NAME = "fashion-aggregator"
20
+
21
+ APP_DIR = Path(__file__).resolve().parent # what is the directory for this application?
22
+ FAVICON = APP_DIR / "t-shirt_1f455.png" # path to a small image for display in browser tab and social media
23
+ README = APP_DIR / "README.md" # path to an app readme file in HTML/markdown
24
+
25
+ DEFAULT_PORT = 11700
26
+
27
+ # Download image embeddings
28
+ api = wandb.Api()
29
+ artifact = api.artifact("ryparmar/fashion-aggregator/unimoda-images:v0")
30
+ artifact.download("fashion_aggregator/artifacts/img-embeddings")
31
+
32
+
33
+ def main(args):
34
+ predictor = PredictorBackend(url=args.model_url)
35
+ frontend = make_frontend(predictor.run, flagging=args.flagging, gantry=args.gantry, app_name=args.application)
36
+ frontend.launch(
37
+ server_name="0.0.0.0", # make server accessible, binding all interfaces # noqa: S104
38
+ server_port=args.port, # set a port to bind to, failing if unavailable
39
+ share=True, # should we create a (temporary) public link on https://gradio.app?
40
+ favicon_path=FAVICON, # what icon should we display in the address bar?
41
+ )
42
+
43
+
44
+ def make_frontend(
45
+ fn: Callable[[Image], str], flagging: bool = False, gantry: bool = False, app_name: str = "fashion-aggregator"
46
+ ):
47
+ """Creates a gradio.Interface frontend for text to image search function."""
48
+
49
+ allow_flagging = "never"
50
+ readme = _load_readme(with_logging=allow_flagging == "manual")
51
+
52
+ # build a basic browser interface to a Python function
53
+ frontend = gr.Interface(
54
+ fn=fn, # which Python function are we interacting with?
55
+ outputs=gr.Gallery(label="Relevant Items"),
56
+ # what input widgets does it need? we configure an image widget
57
+ inputs=gr.components.Textbox(label="Item Description"),
58
+ title="📝 Text2Image 👕", # what should we display at the top of the page?
59
+ thumbnail=FAVICON, # what should we display when the link is shared, e.g. on social media?
60
+ description=__doc__, # what should we display just above the interface?
61
+ article=readme, # what long-form content should we display below the interface?
62
+ cache_examples=False, # should we cache those inputs for faster inference? slows down start
63
+ allow_flagging=allow_flagging, # should we show users the option to "flag" outputs?
64
+ flagging_options=["incorrect", "offensive", "other"], # what options do users have for feedback?
65
+ )
66
+ return frontend
67
+
68
+
69
+ class PredictorBackend:
70
+ """Interface to a backend that serves predictions.
71
+
72
+ To communicate with a backend accessible via a URL, provide the url kwarg.
73
+
74
+ Otherwise, runs a predictor locally.
75
+ """
76
+
77
+ def __init__(self, url=None):
78
+ if url is not None:
79
+ self.url = url
80
+ self._predict = self._predict_from_endpoint
81
+ else:
82
+ model = fa.Retriever()
83
+ self._predict = model.predict
84
+ self._search_images = model.search_images
85
+
86
+ def run(self, text: str):
87
+ pred, metrics = self._predict_with_metrics(text)
88
+ self._log_inference(pred, metrics)
89
+ return pred
90
+
91
+ def _predict_with_metrics(self, text: str) -> Tuple[List[str], Dict[str, float]]:
92
+ paths_and_scores = self._search_images(text)
93
+ metrics = {"mean_score": sum(paths_and_scores["score"]) / len(paths_and_scores["score"])}
94
+ return paths_and_scores["path"], metrics
95
+
96
+ def _log_inference(self, pred, metrics):
97
+ for key, value in metrics.items():
98
+ logging.info(f"METRIC {key} {value}")
99
+ logging.info(f"PRED >begin\n{pred}\nPRED >end")
100
+
101
+
102
+ def _load_readme(with_logging=False):
103
+ with open(README) as f:
104
+ lines = f.readlines()
105
+ if not with_logging:
106
+ lines = lines[: lines.index("<!-- logging content below -->\n")]
107
+
108
+ readme = "".join(lines)
109
+ return readme
110
+
111
+
112
+ def _make_parser():
113
+ parser = argparse.ArgumentParser(description=__doc__)
114
+ parser.add_argument(
115
+ "--model_url",
116
+ default=None,
117
+ type=str,
118
+ help="Identifies a URL to which to send image data. Data is base64-encoded, converted to a utf-8 string, and then set via a POST request as JSON with the key 'image'. Default is None, which instead sends the data to a model running locally.",
119
+ )
120
+ parser.add_argument(
121
+ "--port",
122
+ default=DEFAULT_PORT,
123
+ type=int,
124
+ help=f"Port on which to expose this server. Default is {DEFAULT_PORT}.",
125
+ )
126
+ parser.add_argument(
127
+ "--flagging",
128
+ action="store_true",
129
+ help="Pass this flag to allow users to 'flag' model behavior and provide feedback.",
130
+ )
131
+ parser.add_argument(
132
+ "--gantry",
133
+ action="store_true",
134
+ help="Pass --flagging and this flag to log user feedback to Gantry. Requires GANTRY_API_KEY to be defined as an environment variable.",
135
+ )
136
+ parser.add_argument(
137
+ "--application",
138
+ default=DEFAULT_APPLICATION_NAME,
139
+ type=str,
140
+ help=f"Name of the Gantry application to which feedback should be logged, if --gantry and --flagging are passed. Default is {DEFAULT_APPLICATION_NAME}.",
141
+ )
142
+ return parser
143
+
144
+
145
+ if __name__ == "__main__":
146
+ parser = _make_parser()
147
+ args = parser.parse_args()
148
+ main(args)
fashion_aggregator/__init__.py ADDED
@@ -0,0 +1 @@
 
1
+ """Modules for creating and running a fashion aggregator."""
fashion_aggregator/fashion_aggregator.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Detects a paragraph of text in an input image.
2
+
3
+ Example usage as a script:
4
+
5
+ python fashion_aggregator/fashion_aggregator.py \
6
+ "Two dogs playing in the snow"
7
+ """
8
+ import os
9
+ import argparse
10
+ import pickle
11
+ from pathlib import Path
12
+ from typing import List, Any, Dict
13
+ from PIL import Image
14
+ from pathlib import Path
15
+
16
+ from transformers import AutoTokenizer
17
+ from sentence_transformers import SentenceTransformer, util
18
+ from multilingual_clip import pt_multilingual_clip
19
+ import torch
20
+
21
+
22
+ STAGED_TEXT_ENCODER_MODEL_DIRNAME = Path(__file__).resolve().parent / "artifacts" / "text-encoder"
23
+ STAGED_TEXT_TOKENIZER_DIRNAME = Path(__file__).resolve().parent / "artifacts" / "text-tokenizer"
24
+ STAGED_IMG_ENCODER_MODEL_DIRNAME = Path(__file__).resolve().parent / "artifacts" / "img-encoder"
25
+ STAGED_IMG_EMBEDDINGS_DIRNAME = Path(__file__).resolve().parent / "artifacts" / "img-embeddings"
26
+ RAW_PHOTOS_DIR = Path(__file__).resolve().parent / "data" / "photos"
27
+ MODEL_FILE = "model.pt"
28
+ EMBEDDINGS_FILE = "embeddings.pkl"
29
+
30
+
31
+ class TextEncoder:
32
+ """Encodes the given text"""
33
+
34
+ def __init__(self, model_path='M-CLIP/XLM-Roberta-Large-Vit-B-32'):
35
+ if model_path is None:
36
+ model_path = STAGED_TEXT_ENCODER_MODEL_DIRNAME / MODEL_FILE
37
+ self.model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(model_path)
38
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
39
+
40
+ @torch.no_grad()
41
+ def encode(self, query: str) -> torch.Tensor:
42
+ """Predict/infer text embedding for a given query."""
43
+ query_emb = query_emb = self.model.forward([query], self.tokenizer)
44
+ return query_emb
45
+
46
+
47
+ class ImageEnoder:
48
+ """Encodes the given image"""
49
+
50
+ def __init__(self, model_path='clip-ViT-B-32'):
51
+ if model_path is None:
52
+ model_path = STAGED_IMG_ENCODER_MODEL_DIRNAME / MODEL_FILE
53
+ self.model = SentenceTransformer(model_path)
54
+
55
+ @torch.no_grad()
56
+ def encode(self, image: Image.Image) -> torch.Tensor:
57
+ """Predict/infer text embedding for a given query."""
58
+ image_emb = self.model.encode([image], convert_to_tensor=True, show_progress_bar=False)
59
+ return image_emb
60
+
61
+
62
+ class Retriever:
63
+ """Retrieves relevant images for a given text embedding."""
64
+
65
+ def __init__(self, image_embeddings_path=None):
66
+ if image_embeddings_path is None:
67
+ image_embeddings_path = STAGED_IMG_EMBEDDINGS_DIRNAME / EMBEDDINGS_FILE
68
+
69
+ self.text_encoder = TextEncoder()
70
+ self.image_encoder = ImageEnoder()
71
+
72
+ with open(image_embeddings_path, 'rb') as file:
73
+ self.image_names, self.image_embeddings = pickle.load(file)
74
+ print("Images:", len(self.image_names))
75
+
76
+ @torch.no_grad()
77
+ def predict(self, text_query: str, k: int=10) -> List[Any]:
78
+ """Return top-k relevant items for a given embedding"""
79
+ query_emb = self.text_encoder.encode(text_query)
80
+ relevant_images = util.semantic_search(query_emb, self.image_embeddings, top_k=k)[0]
81
+ return relevant_images
82
+
83
+ @torch.no_grad()
84
+ def search_images(self, text_query: str, k: int=6) -> Dict[str, List[Any]]:
85
+ """Return top-k relevant images for a given embedding"""
86
+ images = self.predict(text_query, k)
87
+ paths_and_scores = {"path": [], "score": []}
88
+ for img in images:
89
+ paths_and_scores["path"].append(os.path.join(RAW_PHOTOS_DIR, self.image_names[img["corpus_id"]]))
90
+ paths_and_scores["score"].append(img["score"])
91
+ return paths_and_scores
92
+
93
+ @torch.no_grad()
94
+ def save(self, output_dir: str = None):
95
+ if output_dir:
96
+ Path(output_dir).mkdir(parents=True, exist_ok=True)
97
+ text_encoder_path = Path(output_dir) / "text-encoder"
98
+ text_tokenizer_path = Path(output_dir) / "text-tokenizer"
99
+ img_encoder_path = Path(output_dir) / "img-encoder"
100
+
101
+ text_encoder_path.mkdir(parents=True, exist_ok=True)
102
+ text_tokenizer_path.mkdir(parents=True, exist_ok=True)
103
+ img_encoder_path.mkdir(parents=True, exist_ok=True)
104
+ else:
105
+ Path(STAGED_TEXT_ENCODER_MODEL_DIRNAME).mkdir(parents=True, exist_ok=True)
106
+ Path(STAGED_TEXT_TOKENIZER_DIRNAME).mkdir(parents=True, exist_ok=True)
107
+ Path(STAGED_IMG_ENCODER_MODEL_DIRNAME).mkdir(parents=True, exist_ok=True)
108
+
109
+
110
+ def main():
111
+ parser = argparse.ArgumentParser(description=__doc__.split("\n")[0])
112
+ parser.add_argument(
113
+ "text_query",
114
+ type=str,
115
+ help="Text query",
116
+ )
117
+ args = parser.parse_args()
118
+
119
+ retriever = Retriever()
120
+ print(f"Given query: {args.text_query}")
121
+ print(retriever.predict(args.text_query))
122
+
123
+
124
+ if __name__ == "__main__":
125
+ main()
fashion_aggregator/util.py ADDED
File without changes
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  sentence-transformers==2.2.2
2
  clip @ git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
3
- multilingual-clip==1.0.10
 
1
  sentence-transformers==2.2.2
2
  clip @ git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
3
+ multilingual-clip==1.0.10
4
+ wandb