Ali Mohsin committed
Commit fac18b7 · 1 Parent(s): acc496a
Files changed (5)
  1. Dockerfile +2 -2
  2. app.py +96 -28
  3. data/polyvore.py +16 -7
  4. inference.py +42 -29
  5. utils/data_fetch.py +36 -129
Dockerfile CHANGED
@@ -15,10 +15,10 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
 
 WORKDIR /app
 
-COPY recommendation/requirements.txt /app/requirements.txt
+COPY requirements.txt /app/requirements.txt
 RUN pip install --upgrade pip && pip install -r /app/requirements.txt
 
-COPY recommendation /app/
+COPY . /app/
 
 EXPOSE 8000
 EXPOSE 7860
app.py CHANGED
@@ -47,13 +47,16 @@ service = InferenceService()
 
 # Non-blocking bootstrap: fetch data, prepare splits, and train if needed in background
 BOOT_STATUS = "idle"
+DATASET_ROOT: Optional[str] = None
 
 
 def _background_bootstrap():
     global BOOT_STATUS
+    global DATASET_ROOT
     try:
         BOOT_STATUS = "preparing-dataset"
         ds_root = ensure_dataset_ready()
+        DATASET_ROOT = ds_root
         if not ds_root:
             BOOT_STATUS = "dataset-not-prepared"
             return
@@ -182,40 +185,104 @@ def gradio_embed(files: List[str]):
     return str([e.tolist() for e in embs])
 
 
-def gradio_compose(files: List[str], occasion: str, weather: str, num_outfits: int):
+def _stitch_strip(imgs: List[Image.Image], height: int = 256, pad: int = 6, bg=(245, 245, 245)) -> Image.Image:
+    if not imgs:
+        return Image.new("RGB", (1, height), color=bg)
+    resized = []
+    for im in imgs:
+        if im.mode != "RGB":
+            im = im.convert("RGB")
+        w, h = im.size
+        scale = height / float(h)
+        nw = max(1, int(w * scale))
+        resized.append(im.resize((nw, height)))
+    total_w = sum(im.size[0] for im in resized) + pad * (len(resized) + 1)
+    out = Image.new("RGB", (total_w, height + 2 * pad), color=bg)
+    x = pad
+    for im in resized:
+        out.paste(im, (x, pad))
+        x += im.size[0] + pad
+    return out
+
+
+def gradio_recommend(files: List[str], occasion: str, weather: str, num_outfits: int):
+    # Return stitched outfit images and a JSON with details
     if not files:
-        return []
+        return [], {"error": "No files uploaded"}
     images = _load_images_from_files(files)
     if not images:
-        return []
-    embs = service.embed_images(images)
+        return [], {"error": "Could not load images"}
+    # Build items that allow on-the-fly embedding in service
     items = [
-        {"id": f"item_{i}", "embedding": embs[i], "category": None, "image_url": None}
-        for i in range(len(embs))
+        {"id": f"item_{i}", "image": images[i], "category": None}
+        for i in range(len(images))
     ]
-    results = service.compose_outfits(items, context={"occasion": occasion, "weather": weather, "num_outfits": int(num_outfits)})
-    # Render as a simple markdown summary
-    lines = []
-    for r in results:
-        lines.append(f"score={r['score']:.3f}, items={r['item_ids']}")
-    return "\n".join(lines)
-
-
-with gr.Blocks() as demo:
-    gr.Markdown("# Dressify Recommendations – HF Test UI")
-    with gr.Tab("Embed"):
+    res = service.compose_outfits(items, context={"occasion": occasion, "weather": weather, "num_outfits": int(num_outfits)})
+    # Prepare stitched previews
+    strips: List[Image.Image] = []
+    for r in res:
+        idxs = []
+        for iid in r.get("item_ids", []):
+            try:
+                idxs.append(int(str(iid).split("_")[-1]))
+            except Exception:
+                continue
+        imgs = [images[i] for i in idxs if 0 <= i < len(images)]
+        strips.append(_stitch_strip(imgs))
+    return strips, {"outfits": res}
+
+
+with gr.Blocks(fill_height=True) as demo:
+    gr.Markdown("## Dressify – Outfit Recommendations\nUpload multiple item images and generate complete looks.")
+    with gr.Tab("Recommend"):
+        inp2 = gr.Files(label="Upload wardrobe images", file_types=["image"], file_count="multiple")
+        with gr.Row():
+            occasion = gr.Dropdown(choices=["casual", "business", "formal", "sport"], value="casual", label="Occasion")
+            weather = gr.Dropdown(choices=["any", "hot", "mild", "cold", "rain"], value="any", label="Weather")
+            num_outfits = gr.Slider(minimum=1, maximum=8, step=1, value=3, label="Num outfits")
+        out_gallery = gr.Gallery(label="Recommended Outfits", columns=1, height=320)
+        out_json = gr.JSON(label="Details")
+        btn2 = gr.Button("Generate Outfits", variant="primary")
+        btn2.click(fn=gradio_recommend, inputs=[inp2, occasion, weather, num_outfits], outputs=[out_gallery, out_json])
+    with gr.Tab("Embed (debug)"):
         inp = gr.Files(label="Upload Items (multiple images)")
         out = gr.Textbox(label="Embeddings (JSON)")
         btn = gr.Button("Compute Embeddings")
         btn.click(fn=gradio_embed, inputs=inp, outputs=out)
-    with gr.Tab("Compose"):
-        inp2 = gr.Files(label="Upload Wardrobe (multiple images)")
-        occasion = gr.Dropdown(choices=["casual", "business", "formal", "sport"], value="casual", label="Occasion")
-        weather = gr.Dropdown(choices=["any", "hot", "mild", "cold", "rain"], value="any", label="Weather")
-        num_outfits = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Num outfits")
-        out2 = gr.Textbox(label="Recommendations")
-        btn2 = gr.Button("Generate")
-        btn2.click(fn=gradio_compose, inputs=[inp2, occasion, weather, num_outfits], outputs=out2)
+    with gr.Tab("Train"):
+        gr.Markdown("Train models on Stylique/Polyvore (70/10/10 split). This runs on the Space hardware.")
+        epochs_res = gr.Slider(1, 50, value=10, step=1, label="ResNet epochs")
+        epochs_vit = gr.Slider(1, 100, value=20, step=1, label="ViT epochs")
+        train_log = gr.Textbox(label="Training Log", lines=10)
+        start_btn = gr.Button("Start Training")
+
+        def start_training(res_epochs: int, vit_epochs: int):
+            def _runner():
+                try:
+                    import subprocess
+                    if not DATASET_ROOT:
+                        train_log.value = "Dataset not ready."
+                        return
+                    export_dir = os.getenv("EXPORT_DIR", "models/exports")
+                    os.makedirs(export_dir, exist_ok=True)
+                    train_log.value = "Training ResNet…\n"
+                    subprocess.run([
+                        "python", "train_resnet.py", "--data_root", DATASET_ROOT, "--epochs", str(res_epochs),
+                        "--out", os.path.join(export_dir, "resnet_item_embedder.pth")
+                    ], check=False)
+                    train_log.value += "\nTraining ViT (triplet)…\n"
+                    subprocess.run([
+                        "python", "train_vit_triplet.py", "--data_root", DATASET_ROOT, "--epochs", str(vit_epochs),
+                        "--export", os.path.join(export_dir, "vit_outfit_model.pth")
+                    ], check=False)
+                    service.reload_models()
+                    train_log.value += "\nDone. Artifacts in models/exports."
+                except Exception as e:
+                    train_log.value += f"\nError: {e}"
+            threading.Thread(target=_runner, daemon=True).start()
+            return "Started"
+
+        start_btn.click(fn=start_training, inputs=[epochs_res, epochs_vit], outputs=train_log)
     with gr.Tab("Downloads"):
        gr.Markdown("Download trained artifacts from models/exports")
        file_list = gr.JSON(label="Artifacts JSON")
@@ -241,7 +308,8 @@ with gr.Blocks() as demo:
 
 
 try:
-    # Mount Gradio onto FastAPI root path
+    # Mount Gradio onto FastAPI root path (disable SSR to avoid stray port fetches)
+    demo.queue()
     app = gr.mount_gradio_app(app, demo, path="/")
 except Exception:
     # In case mounting fails in certain runners, we still want FastAPI to be available
@@ -257,7 +325,7 @@ except Exception:
 
 
 if __name__ == "__main__":
-    # Local testing
-    demo.launch()
+    # Local/Space run
+    demo.queue().launch(ssr_mode=False)
 
 
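The Recommend tab's callback now returns two values: a list of PIL images for the `gr.Gallery` and a dict for the `gr.JSON` component. Below is a minimal, self-contained sketch of that two-output contract; `fake_recommend` and its solid-color strips are hypothetical stand-ins for `gradio_recommend` and the real stitched previews.

```python
# Sketch of the gallery + JSON callback contract used by the Recommend tab.
# Assumes gradio and Pillow are installed; fake_recommend is a stand-in.
import gradio as gr
from PIL import Image

def fake_recommend(n: int):
    # One solid-color "outfit strip" per requested outfit, plus a details dict,
    # mirroring gradio_recommend's (strips, {"outfits": res}) return shape.
    colors = [(200, 80, 80), (80, 200, 80), (80, 80, 200)]
    strips = [Image.new("RGB", (512, 128), colors[i % 3]) for i in range(int(n))]
    details = {"outfits": [{"score": 0.5, "item_ids": []} for _ in strips]}
    return strips, details

with gr.Blocks() as demo:
    n = gr.Slider(1, 3, value=2, step=1, label="Num outfits")
    gallery = gr.Gallery(label="Outfits")
    info = gr.JSON(label="Details")
    gr.Button("Go").click(fn=fake_recommend, inputs=n, outputs=[gallery, info])

if __name__ == "__main__":
    demo.launch()
```

The handler must return a tuple whose order matches the `outputs=[...]` list; Gradio routes the image list to the gallery and the dict to the JSON viewer.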
data/polyvore.py CHANGED
@@ -7,6 +7,7 @@ from torch.utils.data import Dataset
 from PIL import Image
 
 from utils.transforms import build_train_transforms
+from pathlib import Path
 
 
 class PolyvoreTripletDataset(Dataset):
@@ -31,11 +32,21 @@ class PolyvoreTripletDataset(Dataset):
         with open(triplet_path, "r") as f:
             self.samples: List[Dict[str, Any]] = json.load(f)
 
+    def _find_image_path(self, item_id: str) -> str:
+        base = os.path.join(self.root, "images")
+        # direct common extensions
+        for ext in (".jpg", ".jpeg", ".png", ".webp"):
+            p = os.path.join(base, f"{item_id}{ext}")
+            if os.path.isfile(p):
+                return p
+        # recursive fuzzy search
+        for p in Path(base).rglob(f"*{item_id}*"):
+            if p.suffix.lower() in (".jpg", ".jpeg", ".png", ".webp"):
+                return str(p)
+        raise FileNotFoundError(f"Image for item {item_id} not found under {base}")
+
     def _load_image(self, item_id: str) -> Image.Image:
-        # Customize if images are arranged differently
-        img_path = os.path.join(self.root, "images", f"{item_id}.jpg")
-        if not os.path.exists(img_path):
-            raise FileNotFoundError(img_path)
+        img_path = self._find_image_path(item_id)
         return Image.open(img_path).convert("RGB")
 
     def __len__(self) -> int:
@@ -74,9 +85,7 @@ class PolyvoreOutfitDataset(Dataset):
     # If metadata isn't available, we will rely on count >= 3 and let model learn; here, keep as-is.
 
     def _load_image(self, item_id: str) -> Image.Image:
-        img_path = os.path.join(self.root, "images", f"{item_id}.jpg")
-        if not os.path.exists(img_path):
-            raise FileNotFoundError(img_path)
+        img_path = PolyvoreTripletDataset._find_image_path(self, item_id)  # reuse logic
         return Image.open(img_path).convert("RGB")
 
     def __len__(self) -> int:
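The new `_find_image_path` resolves an item id in two passes: a direct probe for `images/<id>.<ext>` over common extensions, then a recursive fuzzy `rglob` fallback for archives that unpack into nested folders. A minimal standalone sketch of that lookup order, with a hypothetical `find_image` helper and `images_root` path:

```python
# Sketch of the two-pass image lookup added to the Polyvore datasets.
import os
from pathlib import Path

EXTS = (".jpg", ".jpeg", ".png", ".webp")

def find_image(images_root: str, item_id: str) -> str:
    # 1) fast path: images/<id>.<ext> for each common extension
    for ext in EXTS:
        p = os.path.join(images_root, f"{item_id}{ext}")
        if os.path.isfile(p):
            return p
    # 2) slow path: recursive fuzzy match anywhere under images/
    for p in Path(images_root).rglob(f"*{item_id}*"):
        if p.suffix.lower() in EXTS:
            return str(p)
    raise FileNotFoundError(f"Image for item {item_id} not found under {images_root}")

# e.g. find_image("./data/Polyvore/images", "123456")
```

The fallback keeps loading alive when image layouts vary between dataset mirrors, at the cost of a full directory walk on every cache miss.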
inference.py CHANGED
@@ -89,54 +89,67 @@ class InferenceService:
 
     @torch.inference_mode()
     def compose_outfits(self, items: List[Dict[str, Any]], context: Dict[str, Any]) -> List[Dict[str, Any]]:
-        # Ensure embeddings
+        # 1) Ensure embeddings for each input item
         proc_items: List[Dict[str, Any]] = []
         for it in items:
-            e = it.get("embedding")
-            if e is None and it.get("image") is not None:
-                # Not used in Gradio path, but kept for completeness
+            emb = it.get("embedding")
+            if emb is None and it.get("image") is not None:
+                # Compute on-the-fly if image provided
                 emb = self.embed_images([it["image"]])[0]
-            elif e is None:
-                # If missing embedding and image, skip
+            if emb is None:
+                # Skip if we cannot get an embedding
                 continue
-            else:
-                emb = np.asarray(e, dtype=np.float32)
-            proc_items.append({"id": it.get("id"), "embedding": emb, "category": it.get("category")})
+            emb_np = np.asarray(emb, dtype=np.float32)
+            proc_items.append({
+                "id": it.get("id"),
+                "embedding": emb_np,
+                "category": it.get("category")
+            })
 
         if len(proc_items) < 2:
             return []
 
-        # Candidate generation: enforce minimum slots (upper, bottom, shoes, accessory) if categories provided
-        rng = np.random.default_rng(42)
+        # 2) Candidate generation
+        rng = np.random.default_rng(int(context.get("seed", 42)))
         num_outfits = int(context.get("num_outfits", 3))
-        min_size, max_size = 3, 5
-        candidates: List[List[int]] = []
+        min_size, max_size = 4, 6
         ids = list(range(len(proc_items)))
-        # slot-aware sampling if categories exist
-        def has_cat(i: int, cat_prefix: str) -> bool:
-            c = (proc_items[i].get("category") or "").lower()
-            return cat_prefix in c
 
-        uppers = [i for i in ids if any(k in (proc_items[i].get("category") or "").lower() for k in ["top", "shirt", "tshirt", "blouse", "jacket", "hoodie"])]
-        bottoms = [i for i in ids if any(k in (proc_items[i].get("category") or "").lower() for k in ["pant", "trouser", "jean", "skirt", "short"])]
-        shoes = [i for i in ids if "shoe" in (proc_items[i].get("category") or "").lower()]
-        accs = [i for i in ids if any(k in (proc_items[i].get("category") or "").lower() for k in ["watch", "belt", "ring", "bracelet", "accessor"])]
+        # Slot-aware pools from categories (best-effort)
+        def cat_str(i: int) -> str:
+            return (proc_items[i].get("category") or "").lower()
 
-        for _ in range(num_outfits * 10):
+        uppers = [i for i in ids if any(k in cat_str(i) for k in ["top", "shirt", "tshirt", "blouse", "jacket", "hoodie"])]
+        bottoms = [i for i in ids if any(k in cat_str(i) for k in ["pant", "trouser", "jean", "skirt", "short"])]
+        shoes = [i for i in ids if any(k in cat_str(i) for k in ["shoe", "sneaker", "boot", "heel"])]
+        accs = [i for i in ids if any(k in cat_str(i) for k in ["watch", "belt", "ring", "bracelet", "accessor", "bag", "hat"])]
+
+        candidates: List[List[int]] = []
+        num_samples = max(num_outfits * 12, 24)
+        for _ in range(num_samples):
             if uppers and bottoms and shoes and accs:
-                subset = [rng.choice(uppers).item(), rng.choice(bottoms).item(), rng.choice(shoes).item(), rng.choice(accs).item()]
-                # optional: add one more random
+                subset = [
+                    int(rng.choice(uppers)),
+                    int(rng.choice(bottoms)),
+                    int(rng.choice(shoes)),
+                    int(rng.choice(accs)),
+                ]
+                # Optional: add one more random distinct item
                 remain = list(set(ids) - set(subset))
                 if remain and rng.random() < 0.5:
-                    subset.append(rng.choice(remain).item())
+                    subset.append(int(rng.choice(remain)))
             else:
-                k = rng.integers(min_size, max_size + 1)
-                subset = rng.choice(ids, size=int(k), replace=False).tolist()
+                k = int(rng.integers(min_size, max_size + 1))
+                subset = list(map(int, rng.choice(ids, size=k, replace=False).tolist()))
             candidates.append(subset)
 
-        # Score using ViT
+        # 3) Score using ViT
         def score_subset(idx_subset: List[int]) -> float:
-            embs = torch.tensor(np.stack([proc_items[i]["embedding"] for i in idx_subset]), dtype=torch.float32, device=self.device)
+            embs = torch.tensor(
+                np.stack([proc_items[i]["embedding"] for i in idx_subset], axis=0),
+                dtype=torch.float32,
+                device=self.device,
+            )  # (N, D)
             embs = embs.unsqueeze(0)  # (1, N, D)
             s = self.vit.score_compatibility(embs).item()
             return float(s)
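`compose_outfits` is a sample-then-score search: draw random candidate subsets (slot-aware when categories are known, otherwise 4–6 items uniformly), score each with the ViT compatibility head, and keep the best. A self-contained sketch of that loop, using mean pairwise cosine similarity as a stand-in for `self.vit.score_compatibility`:

```python
# Sketch of the sample-then-score loop in compose_outfits.
# The cosine scorer is a hypothetical stand-in for the ViT head.
import numpy as np

rng = np.random.default_rng(42)
embs = rng.normal(size=(8, 16)).astype(np.float32)  # 8 items, D=16
embs /= np.linalg.norm(embs, axis=1, keepdims=True)

def score(subset):
    # stand-in compatibility: mean pairwise cosine similarity within the subset
    sub = embs[subset]
    sims = sub @ sub.T
    n = len(subset)
    return float((sims.sum() - np.trace(sims)) / (n * (n - 1)))

ids = list(range(len(embs)))
# uniform fallback path: subsets of 4..6 distinct items, as in min_size/max_size
candidates = [list(map(int, rng.choice(ids, size=int(rng.integers(4, 7)), replace=False)))
              for _ in range(24)]
ranked = sorted(candidates, key=score, reverse=True)
for subset in ranked[:3]:
    print(subset, round(score(subset), 3))
```

Casting `rng.choice(...)` results through `int(...)`, as the commit does, keeps the item ids JSON-serializable instead of leaking `numpy` scalar types into the API response.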
utils/data_fetch.py CHANGED
@@ -1,148 +1,55 @@
 import os
-import shutil
 import zipfile
 from pathlib import Path
-from typing import Optional, List
+from typing import Optional
 
-import requests
-
-try:
-    from huggingface_hub import snapshot_download  # type: ignore
-except Exception:  # pragma: no cover
-    snapshot_download = None
+from huggingface_hub import snapshot_download  # type: ignore
 
-try:
-    import kagglehub  # type: ignore
-    from kagglehub import KaggleDatasetAdapter  # type: ignore
-except Exception:  # pragma: no cover
-    kagglehub = None
-    KaggleDatasetAdapter = None
-
-
-def _download_zip(url: str, dest_dir: str) -> str:
-    os.makedirs(dest_dir, exist_ok=True)
-    local_zip = os.path.join(dest_dir, "dataset.zip")
-    with requests.get(url, stream=True, timeout=60) as r:
-        r.raise_for_status()
-        with open(local_zip, "wb") as f:
-            for chunk in r.iter_content(chunk_size=1024 * 1024):
-                if chunk:
-                    f.write(chunk)
-    with zipfile.ZipFile(local_zip, "r") as zf:
-        zf.extractall(dest_dir)
-    os.remove(local_zip)
-    return dest_dir
-
-
-def _unzip_inner_archives(root: str) -> None:
-    """Find and extract any zip files inside root (e.g., images.zip)."""
-    for dirpath, _dirnames, filenames in os.walk(root):
-        for fn in filenames:
-            if fn.lower().endswith(".zip"):
-                zpath = os.path.join(dirpath, fn)
-                try:
-                    with zipfile.ZipFile(zpath, "r") as zf:
-                        zf.extractall(dirpath)
-                    # keep original zip to avoid repeated work? remove to save disk
-                    try:
-                        os.remove(zpath)
-                    except Exception:
-                        pass
-                except Exception as e:  # pragma: no cover
-                    print(f"Failed to unzip inner archive {zpath}: {e}")
-
-
-def _ensure_images_dir(root: str) -> None:
-    """Ensure a stable images/ path exists under root. Create a symlink if needed."""
-    images_root = os.path.join(root, "images")
-    if os.path.isdir(images_root):
+def _unzip_images_if_needed(root: str) -> None:
+    """
+    If an archive like images.zip exists in the dataset root, extract it to root/images.
+    """
+    images_dir = os.path.join(root, "images")
+    if os.path.isdir(images_dir) and any(Path(images_dir).glob("*")):
         return
-    # Try to find a folder with many jpg/png files
-    candidate_dirs: List[str] = []
-    for dirpath, dirnames, filenames in os.walk(root):
-        if dirpath == root:
-            # skip root level files, look deeper
-            continue
-        img_files = [f for f in filenames if f.lower().endswith((".jpg", ".jpeg", ".png"))]
-        if len(img_files) > 1000:  # heuristic: big image folder
-            candidate_dirs.append(dirpath)
-    # Prefer the shallowest candidate
-    candidate_dirs.sort(key=lambda p: len(Path(p).parts))
-    if candidate_dirs:
-        src = candidate_dirs[0]
-        try:
-            os.symlink(src, images_root)
-            print(f"Created images symlink: {images_root} -> {src}")
-        except Exception:
-            # fallback: create folder and leave it empty (training will fail fast if missing)
-            os.makedirs(images_root, exist_ok=True)
-    else:
-        os.makedirs(images_root, exist_ok=True)
+    # Common zip names at root or subfolders
+    candidates = [os.path.join(root, name) for name in ("images.zip", "polyvore-images.zip", "imgs.zip")]
+    # Also search recursively for any *images*.zip
+    for p in Path(root).rglob("*images*.zip"):
+        candidates.append(str(p))
+    for zpath in candidates:
+        if os.path.isfile(zpath):
+            os.makedirs(images_dir, exist_ok=True)
+            with zipfile.ZipFile(zpath, "r") as zf:
+                zf.extractall(images_dir)
+            return
 
 
 def ensure_dataset_ready() -> Optional[str]:
     """
-    Ensure Polyvore dataset is present locally.
-    Priority:
-    1) If POLYVORE_ROOT exists and has splits, return it
-    2) Try Hugging Face dataset repo (defaults to Stylique/Polyvore if not set)
-    3) If DATA_ZIP_URL is set, download and unzip
-    4) Try KaggleHub (best-effort)
-    Returns resolved root path or None if nothing done.
+    Self-contained dataset fetcher for the Polyvore dataset from Hugging Face.
+    - Downloads the dataset repo Stylique/Polyvore into ./data/Polyvore
+    - Unzips images.zip into ./data/Polyvore/images
+    - Returns the dataset root path
     """
-    root = os.getenv("POLYVORE_ROOT", "./data/Polyvore")
-    auto_fetch = os.getenv("AUTO_FETCH_DATA", "true").lower() == "true"
+    root = os.path.abspath(os.path.join(os.getcwd(), "data", "Polyvore"))
     Path(root).mkdir(parents=True, exist_ok=True)
 
-    # Already prepared?
-    if os.path.isdir(os.path.join(root, "splits")):
-        _unzip_inner_archives(root)
-        _ensure_images_dir(root)
+    # If already present, ensure images are unzipped and return
+    _unzip_images_if_needed(root)
+    if os.path.isdir(os.path.join(root, "images")):
         return root
-    if not auto_fetch:
-        return None
 
-    # Try HF dataset repo
-    repo = os.getenv("HF_DATASET_REPO", "Stylique/Polyvore")
-    if repo and snapshot_download is not None:
-        try:
-            snapshot_download(repo, repo_type="dataset", local_dir=root)
-            _unzip_inner_archives(root)
-            _ensure_images_dir(root)
-            # If splits not provided, they'll be prepared by the caller
-            return root
-        except Exception as e:  # pragma: no cover
-            print(f"HF dataset download failed: {e}")
-
-    # Try ZIP URL
-    zip_url = os.getenv("DATA_ZIP_URL")
-    if zip_url:
-        try:
-            _download_zip(zip_url, root)
-            _unzip_inner_archives(root)
-            _ensure_images_dir(root)
-        except Exception as e:  # pragma: no cover
-            print(f"ZIP download failed: {e}")
-            return None
-
-    # Try KaggleHub (no Kaggle keys required for public datasets)
-    if kagglehub is not None and KaggleDatasetAdapter is not None:
-        try:
-            # Attempt to load core file to trigger dataset download locally
-            # User can override POLYVORE_FILE_PATH to select a specific CSV/JSON
-            file_path = os.getenv("POLYVORE_FILE_PATH", "")
-            kagglehub.load_dataset(
-                KaggleDatasetAdapter.PANDAS,
-                "dnepozitek/polyvore-outfits",
-                file_path,
-            )
-            # KaggleHub stores under ~/.cache/kagglehub/datasets/<slug>/...; copy to root if needed
-            # For simplicity, assume user will run prepare script using POLYVORE_ROOT pointing to extracted images
-            _unzip_inner_archives(root)
-            _ensure_images_dir(root)
-        except Exception as e:  # pragma: no cover
-            print(f"KaggleHub download failed: {e}")
+    # Download the HF dataset snapshot into root
+    try:
+        snapshot_download("Stylique/Polyvore", repo_type="dataset", local_dir=root, local_dir_use_symlinks=False)
+    except Exception as e:  # pragma: no cover
+        print(f"Failed to download Stylique/Polyvore dataset: {e}")
+        return None
 
-    return root
+    # Unzip images if needed
+    _unzip_images_if_needed(root)
+    return root if os.path.isdir(os.path.join(root, "images")) else None
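The rewritten `ensure_dataset_ready` is now a fixed pipeline: snapshot the `Stylique/Polyvore` dataset repo into `./data/Polyvore`, then extract any `*images*.zip` archive into `images/`. A minimal sketch of the same flow, assuming `huggingface_hub` is installed and the dataset repo is reachable (the `local_dir_use_symlinks` flag is omitted here for brevity):

```python
# Sketch of the fetch-then-unzip flow; paths mirror the hard-coded
# ./data/Polyvore layout that ensure_dataset_ready() now assumes.
import os
import zipfile
from pathlib import Path

from huggingface_hub import snapshot_download

root = os.path.abspath("data/Polyvore")
Path(root).mkdir(parents=True, exist_ok=True)

# Download the dataset repo snapshot into the local root
snapshot_download("Stylique/Polyvore", repo_type="dataset", local_dir=root)

# Extract an images archive, if the repo ships one and images/ is not yet populated
zips = list(Path(root).rglob("*images*.zip"))
if zips and not os.path.isdir(os.path.join(root, "images")):
    with zipfile.ZipFile(zips[0]) as zf:
        zf.extractall(os.path.join(root, "images"))

print("dataset root:", root)
```

Compared to the old multi-backend fallback (env-configured root, ZIP URL, KaggleHub), the single hard-coded source trades flexibility for a predictable layout on the Space.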