restructure into a single file
- app.py +80 -3
- fashion_aggregator/__init__.py +0 -1
- fashion_aggregator/fashion_aggregator.py +0 -125
- fashion_aggregator/util.py +0 -0
app.py
CHANGED
@@ -5,6 +5,18 @@ import os
 import wandb
 import gradio as gr
 
+import zipfile
+import pickle
+from pathlib import Path
+from typing import List, Any, Dict
+from PIL import Image
+from pathlib import Path
+
+from transformers import AutoTokenizer
+from sentence_transformers import SentenceTransformer, util
+from multilingual_clip import pt_multilingual_clip
+import torch
+
 from pathlib import Path
 from typing import Callable, Dict, List, Tuple
 from PIL.Image import Image
@@ -24,11 +36,76 @@ README = APP_DIR / "README.md" # path to an app readme file in HTML/markdown
 
 DEFAULT_PORT = 11700
 
-
+EMBEDDINGS_DIR = "artifacts/img-embeddings"
+EMBEDDINGS_FILE = os.path.join(EMBEDDINGS_DIR, "embeddings.pkl")
+RAW_PHOTOS_DIR = "artifacts/raw-photos"
+
+# Download image embeddings and raw photos
 wandb.login(key=os.getenv('wandb'))
 api = wandb.Api()
-
-
+artifact_embeddings = api.artifact("ryparmar/fashion-aggregator/unimoda-images:v1")
+artifact_embeddings.download(EMBEDDINGS_DIR)
+artifact_raw_photos = api.artifact("ryparmar/fashion-aggregator/unimoda-raw-images:v1")
+artifact_raw_photos.download("artifacts")
+
+with zipfile.ZipFile("artifacts/unimoda.zip", 'r') as zip_ref:
+    zip_ref.extractall(RAW_PHOTOS_DIR)
+
+
+class TextEncoder:
+    """Encodes the given text"""
+
+    def __init__(self, model_path='M-CLIP/XLM-Roberta-Large-Vit-B-32'):
+        self.model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(model_path)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+    @torch.no_grad()
+    def encode(self, query: str) -> torch.Tensor:
+        """Predict/infer text embedding for a given query."""
+        query_emb = self.model.forward([query], self.tokenizer)
+        return query_emb
+
+
+class ImageEnoder:
+    """Encodes the given image"""
+
+    def __init__(self, model_path='clip-ViT-B-32'):
+        self.model = SentenceTransformer(model_path)
+
+    @torch.no_grad()
+    def encode(self, image: Image.Image) -> torch.Tensor:
+        """Predict/infer text embedding for a given query."""
+        image_emb = self.model.encode([image], convert_to_tensor=True, show_progress_bar=False)
+        return image_emb
+
+
+class Retriever:
+    """Retrieves relevant images for a given text embedding."""
+
+    def __init__(self, image_embeddings_path=None):
+        self.text_encoder = TextEncoder()
+        self.image_encoder = ImageEnoder()
+
+        with open(image_embeddings_path, 'rb') as file:
+            self.image_names, self.image_embeddings = pickle.load(file)
+        print("Images:", len(self.image_names))
+
+    @torch.no_grad()
+    def predict(self, text_query: str, k: int=10) -> List[Any]:
+        """Return top-k relevant items for a given embedding"""
+        query_emb = self.text_encoder.encode(text_query)
+        relevant_images = util.semantic_search(query_emb, self.image_embeddings, top_k=k)[0]
+        return relevant_images
+
+    @torch.no_grad()
+    def search_images(self, text_query: str, k: int=6) -> Dict[str, List[Any]]:
+        """Return top-k relevant images for a given embedding"""
+        images = self.predict(text_query, k)
+        paths_and_scores = {"path": [], "score": []}
+        for img in images:
+            paths_and_scores["path"].append(os.path.join(RAW_PHOTOS_DIR, self.image_names[img["corpus_id"]]))
+            paths_and_scores["score"].append(img["score"])
+        return paths_and_scores
 
 
 def main(args):
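After this restructuring, everything the app needs is defined in app.py and exposed through the Retriever class. A minimal usage sketch, illustrative only: the Gradio wiring inside main() is not shown in this diff, and the lines below assume the W&B artifacts have already been downloaded so that EMBEDDINGS_FILE exists.

# Illustrative sketch, not part of the commit: query the retriever directly.
# Assumes the embeddings pickle was downloaded to EMBEDDINGS_FILE as in the code above.
retriever = Retriever(image_embeddings_path=EMBEDDINGS_FILE)
results = retriever.search_images("red summer dress", k=6)  # {"path": [...], "score": [...]}
for path, score in zip(results["path"], results["score"]):
    print(f"{score:.3f}  {path}")

Retriever.predict returns the raw sentence_transformers util.semantic_search hits (a list of {"corpus_id": ..., "score": ...} dicts); search_images maps the corpus IDs back to image paths under RAW_PHOTOS_DIR.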
fashion_aggregator/__init__.py
DELETED
@@ -1 +0,0 @@
-"""Modules for creating and running a fashion aggregator."""
fashion_aggregator/fashion_aggregator.py
DELETED
@@ -1,125 +0,0 @@
-"""Detects a paragraph of text in an input image.
-
-Example usage as a script:
-
-    python fashion_aggregator/fashion_aggregator.py \
-        "Two dogs playing in the snow"
-"""
-import os
-import argparse
-import pickle
-from pathlib import Path
-from typing import List, Any, Dict
-from PIL import Image
-from pathlib import Path
-
-from transformers import AutoTokenizer
-from sentence_transformers import SentenceTransformer, util
-from multilingual_clip import pt_multilingual_clip
-import torch
-
-
-STAGED_TEXT_ENCODER_MODEL_DIRNAME = Path(__file__).resolve().parent / "artifacts" / "text-encoder"
-STAGED_TEXT_TOKENIZER_DIRNAME = Path(__file__).resolve().parent / "artifacts" / "text-tokenizer"
-STAGED_IMG_ENCODER_MODEL_DIRNAME = Path(__file__).resolve().parent / "artifacts" / "img-encoder"
-STAGED_IMG_EMBEDDINGS_DIRNAME = Path(__file__).resolve().parent / "artifacts" / "img-embeddings"
-RAW_PHOTOS_DIR = Path(__file__).resolve().parent / "data" / "photos"
-MODEL_FILE = "model.pt"
-EMBEDDINGS_FILE = "embeddings.pkl"
-
-
-class TextEncoder:
-    """Encodes the given text"""
-
-    def __init__(self, model_path='M-CLIP/XLM-Roberta-Large-Vit-B-32'):
-        if model_path is None:
-            model_path = STAGED_TEXT_ENCODER_MODEL_DIRNAME / MODEL_FILE
-        self.model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(model_path)
-        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
-
-    @torch.no_grad()
-    def encode(self, query: str) -> torch.Tensor:
-        """Predict/infer text embedding for a given query."""
-        query_emb = query_emb = self.model.forward([query], self.tokenizer)
-        return query_emb
-
-
-class ImageEnoder:
-    """Encodes the given image"""
-
-    def __init__(self, model_path='clip-ViT-B-32'):
-        if model_path is None:
-            model_path = STAGED_IMG_ENCODER_MODEL_DIRNAME / MODEL_FILE
-        self.model = SentenceTransformer(model_path)
-
-    @torch.no_grad()
-    def encode(self, image: Image.Image) -> torch.Tensor:
-        """Predict/infer text embedding for a given query."""
-        image_emb = self.model.encode([image], convert_to_tensor=True, show_progress_bar=False)
-        return image_emb
-
-
-class Retriever:
-    """Retrieves relevant images for a given text embedding."""
-
-    def __init__(self, image_embeddings_path=None):
-        if image_embeddings_path is None:
-            image_embeddings_path = STAGED_IMG_EMBEDDINGS_DIRNAME / EMBEDDINGS_FILE
-
-        self.text_encoder = TextEncoder()
-        self.image_encoder = ImageEnoder()
-
-        with open(image_embeddings_path, 'rb') as file:
-            self.image_names, self.image_embeddings = pickle.load(file)
-        print("Images:", len(self.image_names))
-
-    @torch.no_grad()
-    def predict(self, text_query: str, k: int=10) -> List[Any]:
-        """Return top-k relevant items for a given embedding"""
-        query_emb = self.text_encoder.encode(text_query)
-        relevant_images = util.semantic_search(query_emb, self.image_embeddings, top_k=k)[0]
-        return relevant_images
-
-    @torch.no_grad()
-    def search_images(self, text_query: str, k: int=6) -> Dict[str, List[Any]]:
-        """Return top-k relevant images for a given embedding"""
-        images = self.predict(text_query, k)
-        paths_and_scores = {"path": [], "score": []}
-        for img in images:
-            paths_and_scores["path"].append(os.path.join(RAW_PHOTOS_DIR, self.image_names[img["corpus_id"]]))
-            paths_and_scores["score"].append(img["score"])
-        return paths_and_scores
-
-    @torch.no_grad()
-    def save(self, output_dir: str = None):
-        if output_dir:
-            Path(output_dir).mkdir(parents=True, exist_ok=True)
-            text_encoder_path = Path(output_dir) / "text-encoder"
-            text_tokenizer_path = Path(output_dir) / "text-tokenizer"
-            img_encoder_path = Path(output_dir) / "img-encoder"
-
-            text_encoder_path.mkdir(parents=True, exist_ok=True)
-            text_tokenizer_path.mkdir(parents=True, exist_ok=True)
-            img_encoder_path.mkdir(parents=True, exist_ok=True)
-        else:
-            Path(STAGED_TEXT_ENCODER_MODEL_DIRNAME).mkdir(parents=True, exist_ok=True)
-            Path(STAGED_TEXT_TOKENIZER_DIRNAME).mkdir(parents=True, exist_ok=True)
-            Path(STAGED_IMG_ENCODER_MODEL_DIRNAME).mkdir(parents=True, exist_ok=True)
-
-
-def main():
-    parser = argparse.ArgumentParser(description=__doc__.split("\n")[0])
-    parser.add_argument(
-        "text_query",
-        type=str,
-        help="Text query",
-    )
-    args = parser.parse_args()
-
-    retriever = Retriever()
-    print(f"Given query: {args.text_query}")
-    print(retriever.predict(args.text_query))
-
-
-if __name__ == "__main__":
-    main()
fashion_aggregator/util.py
DELETED
File without changes