ryparmar committed
Commit 973254d · 1 Parent(s): ea041b2

restructure into a single file

app.py CHANGED
@@ -5,6 +5,18 @@ import os
 import wandb
 import gradio as gr
 
+import zipfile
+import pickle
+from pathlib import Path
+from typing import List, Any, Dict
+from PIL import Image
+from pathlib import Path
+
+from transformers import AutoTokenizer
+from sentence_transformers import SentenceTransformer, util
+from multilingual_clip import pt_multilingual_clip
+import torch
+
 from pathlib import Path
 from typing import Callable, Dict, List, Tuple
 from PIL.Image import Image
@@ -24,11 +36,76 @@ README = APP_DIR / "README.md" # path to an app readme file in HTML/markdown
 
 DEFAULT_PORT = 11700
 
-# Download image embeddings
+EMBEDDINGS_DIR = "artifacts/img-embeddings"
+EMBEDDINGS_FILE = os.path.join(EMBEDDINGS_DIR, "embeddings.pkl")
+RAW_PHOTOS_DIR = "artifacts/raw-photos"
+
+# Download image embeddings and raw photos
 wandb.login(key=os.getenv('wandb'))
 api = wandb.Api()
-artifact = api.artifact("ryparmar/fashion-aggregator/unimoda-images:v1")
-artifact.download("fashion_aggregator/artifacts/img-embeddings")
+artifact_embeddings = api.artifact("ryparmar/fashion-aggregator/unimoda-images:v1")
+artifact_embeddings.download(EMBEDDINGS_DIR)
+artifact_raw_photos = api.artifact("ryparmar/fashion-aggregator/unimoda-raw-images:v1")
+artifact_raw_photos.download("artifacts")
+
+with zipfile.ZipFile("artifacts/unimoda.zip", 'r') as zip_ref:
+    zip_ref.extractall(RAW_PHOTOS_DIR)
+
+
+class TextEncoder:
+    """Encodes the given text"""
+
+    def __init__(self, model_path='M-CLIP/XLM-Roberta-Large-Vit-B-32'):
+        self.model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(model_path)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+    @torch.no_grad()
+    def encode(self, query: str) -> torch.Tensor:
+        """Predict/infer text embedding for a given query."""
+        query_emb = self.model.forward([query], self.tokenizer)
+        return query_emb
+
+
+class ImageEnoder:
+    """Encodes the given image"""
+
+    def __init__(self, model_path='clip-ViT-B-32'):
+        self.model = SentenceTransformer(model_path)
+
+    @torch.no_grad()
+    def encode(self, image: Image.Image) -> torch.Tensor:
+        """Predict/infer image embedding for a given image."""
+        image_emb = self.model.encode([image], convert_to_tensor=True, show_progress_bar=False)
+        return image_emb
+
+
+class Retriever:
+    """Retrieves relevant images for a given text query."""
+
+    def __init__(self, image_embeddings_path=None):
+        self.text_encoder = TextEncoder()
+        self.image_encoder = ImageEnoder()
+
+        with open(image_embeddings_path, 'rb') as file:
+            self.image_names, self.image_embeddings = pickle.load(file)
+        print("Images:", len(self.image_names))
+
+    @torch.no_grad()
+    def predict(self, text_query: str, k: int=10) -> List[Any]:
+        """Return top-k relevant items for a given text query"""
+        query_emb = self.text_encoder.encode(text_query)
+        relevant_images = util.semantic_search(query_emb, self.image_embeddings, top_k=k)[0]
+        return relevant_images
+
+    @torch.no_grad()
+    def search_images(self, text_query: str, k: int=6) -> Dict[str, List[Any]]:
+        """Return top-k relevant images for a given text query"""
+        images = self.predict(text_query, k)
+        paths_and_scores = {"path": [], "score": []}
+        for img in images:
+            paths_and_scores["path"].append(os.path.join(RAW_PHOTOS_DIR, self.image_names[img["corpus_id"]]))
+            paths_and_scores["score"].append(img["score"])
+        return paths_and_scores
 
 
 def main(args):
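
With the retrieval code now inlined in app.py, the new classes can be exercised directly. A minimal usage sketch, not part of the commit, assuming the artifact download and unzip above have already run at import time; the query string is made up for illustration:

# Hypothetical usage of the classes added above.
retriever = Retriever(image_embeddings_path=EMBEDDINGS_FILE)

# Free-text query; M-CLIP makes this multilingual.
results = retriever.search_images("red summer dress", k=6)
for path, score in zip(results["path"], results["score"]):
    print(f"{score:.3f}  {path}")  # cosine score, photo path under RAW_PHOTOS_DIR

search_images returns parallel "path" and "score" lists, which display code can zip together for a gallery with captions.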
fashion_aggregator/__init__.py DELETED
@@ -1 +0,0 @@
-"""Modules for creating and running a fashion aggregator."""
fashion_aggregator/fashion_aggregator.py DELETED
@@ -1,125 +0,0 @@
-"""Detects a paragraph of text in an input image.
-
-Example usage as a script:
-
-    python fashion_aggregator/fashion_aggregator.py \
-        "Two dogs playing in the snow"
-"""
-import os
-import argparse
-import pickle
-from pathlib import Path
-from typing import List, Any, Dict
-from PIL import Image
-from pathlib import Path
-
-from transformers import AutoTokenizer
-from sentence_transformers import SentenceTransformer, util
-from multilingual_clip import pt_multilingual_clip
-import torch
-
-
-STAGED_TEXT_ENCODER_MODEL_DIRNAME = Path(__file__).resolve().parent / "artifacts" / "text-encoder"
-STAGED_TEXT_TOKENIZER_DIRNAME = Path(__file__).resolve().parent / "artifacts" / "text-tokenizer"
-STAGED_IMG_ENCODER_MODEL_DIRNAME = Path(__file__).resolve().parent / "artifacts" / "img-encoder"
-STAGED_IMG_EMBEDDINGS_DIRNAME = Path(__file__).resolve().parent / "artifacts" / "img-embeddings"
-RAW_PHOTOS_DIR = Path(__file__).resolve().parent / "data" / "photos"
-MODEL_FILE = "model.pt"
-EMBEDDINGS_FILE = "embeddings.pkl"
-
-
-class TextEncoder:
-    """Encodes the given text"""
-
-    def __init__(self, model_path='M-CLIP/XLM-Roberta-Large-Vit-B-32'):
-        if model_path is None:
-            model_path = STAGED_TEXT_ENCODER_MODEL_DIRNAME / MODEL_FILE
-        self.model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(model_path)
-        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
-
-    @torch.no_grad()
-    def encode(self, query: str) -> torch.Tensor:
-        """Predict/infer text embedding for a given query."""
-        query_emb = query_emb = self.model.forward([query], self.tokenizer)
-        return query_emb
-
-
-class ImageEnoder:
-    """Encodes the given image"""
-
-    def __init__(self, model_path='clip-ViT-B-32'):
-        if model_path is None:
-            model_path = STAGED_IMG_ENCODER_MODEL_DIRNAME / MODEL_FILE
-        self.model = SentenceTransformer(model_path)
-
-    @torch.no_grad()
-    def encode(self, image: Image.Image) -> torch.Tensor:
-        """Predict/infer text embedding for a given query."""
-        image_emb = self.model.encode([image], convert_to_tensor=True, show_progress_bar=False)
-        return image_emb
-
-
-class Retriever:
-    """Retrieves relevant images for a given text embedding."""
-
-    def __init__(self, image_embeddings_path=None):
-        if image_embeddings_path is None:
-            image_embeddings_path = STAGED_IMG_EMBEDDINGS_DIRNAME / EMBEDDINGS_FILE
-
-        self.text_encoder = TextEncoder()
-        self.image_encoder = ImageEnoder()
-
-        with open(image_embeddings_path, 'rb') as file:
-            self.image_names, self.image_embeddings = pickle.load(file)
-        print("Images:", len(self.image_names))
-
-    @torch.no_grad()
-    def predict(self, text_query: str, k: int=10) -> List[Any]:
-        """Return top-k relevant items for a given embedding"""
-        query_emb = self.text_encoder.encode(text_query)
-        relevant_images = util.semantic_search(query_emb, self.image_embeddings, top_k=k)[0]
-        return relevant_images
-
-    @torch.no_grad()
-    def search_images(self, text_query: str, k: int=6) -> Dict[str, List[Any]]:
-        """Return top-k relevant images for a given embedding"""
-        images = self.predict(text_query, k)
-        paths_and_scores = {"path": [], "score": []}
-        for img in images:
-            paths_and_scores["path"].append(os.path.join(RAW_PHOTOS_DIR, self.image_names[img["corpus_id"]]))
-            paths_and_scores["score"].append(img["score"])
-        return paths_and_scores
-
-    @torch.no_grad()
-    def save(self, output_dir: str = None):
-        if output_dir:
-            Path(output_dir).mkdir(parents=True, exist_ok=True)
-            text_encoder_path = Path(output_dir) / "text-encoder"
-            text_tokenizer_path = Path(output_dir) / "text-tokenizer"
-            img_encoder_path = Path(output_dir) / "img-encoder"
-
-            text_encoder_path.mkdir(parents=True, exist_ok=True)
-            text_tokenizer_path.mkdir(parents=True, exist_ok=True)
-            img_encoder_path.mkdir(parents=True, exist_ok=True)
-        else:
-            Path(STAGED_TEXT_ENCODER_MODEL_DIRNAME).mkdir(parents=True, exist_ok=True)
-            Path(STAGED_TEXT_TOKENIZER_DIRNAME).mkdir(parents=True, exist_ok=True)
-            Path(STAGED_IMG_ENCODER_MODEL_DIRNAME).mkdir(parents=True, exist_ok=True)
-
-
-def main():
-    parser = argparse.ArgumentParser(description=__doc__.split("\n")[0])
-    parser.add_argument(
-        "text_query",
-        type=str,
-        help="Text query",
-    )
-    args = parser.parse_args()
-
-    retriever = Retriever()
-    print(f"Given query: {args.text_query}")
-    print(retriever.predict(args.text_query))
-
-
-if __name__ == "__main__":
-    main()
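
Both the deleted module and its replacement in app.py delegate ranking to sentence_transformers.util.semantic_search, which returns one hit list per query, each hit a dict with "corpus_id" and "score" keys; that is why predict ends with [0], taking the hits for its single query. A minimal sketch with toy tensors (shapes and values are illustrative, not from the project):

import torch
from sentence_transformers import util

query_emb = torch.randn(1, 512)     # one query embedding (toy values)
corpus_emb = torch.randn(100, 512)  # pretend corpus of 100 image embeddings

hits = util.semantic_search(query_emb, corpus_emb, top_k=3)[0]
for hit in hits:
    print(hit["corpus_id"], hit["score"])  # index into image_names, cosine score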
fashion_aggregator/util.py DELETED
File without changes