Spaces (paused)
Ali Mohsin committed
Commit · fac18b7 · Parent(s): acc496a
fixes
Browse files:
- Dockerfile +2 -2
- app.py +96 -28
- data/polyvore.py +16 -7
- inference.py +42 -29
- utils/data_fetch.py +36 -129
Dockerfile
CHANGED
@@ -15,10 +15,10 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
 
 WORKDIR /app
 
-COPY …
+COPY requirements.txt /app/requirements.txt
 RUN pip install --upgrade pip && pip install -r /app/requirements.txt
 
-COPY …
+COPY . /app/
 
 EXPOSE 8000
 EXPOSE 7860
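Note: copying requirements.txt and running pip install before copying the rest of the source lets Docker cache the dependency layer, so rebuilds that only touch application code skip reinstalling packages. (The removed COPY lines are truncated in the source view; the caching split is the usual motivation for this ordering.)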
app.py
CHANGED
@@ -47,13 +47,16 @@ service = InferenceService()
 
 # Non-blocking bootstrap: fetch data, prepare splits, and train if needed in background
 BOOT_STATUS = "idle"
+DATASET_ROOT: Optional[str] = None
 
 
 def _background_bootstrap():
     global BOOT_STATUS
+    global DATASET_ROOT
     try:
         BOOT_STATUS = "preparing-dataset"
         ds_root = ensure_dataset_ready()
+        DATASET_ROOT = ds_root
         if not ds_root:
             BOOT_STATUS = "dataset-not-prepared"
             return

@@ -182,40 +185,104 @@ def gradio_embed(files: List[str]):
     return str([e.tolist() for e in embs])
 
 
-def …
+def _stitch_strip(imgs: List[Image.Image], height: int = 256, pad: int = 6, bg=(245, 245, 245)) -> Image.Image:
+    if not imgs:
+        return Image.new("RGB", (1, height), color=bg)
+    resized = []
+    for im in imgs:
+        if im.mode != "RGB":
+            im = im.convert("RGB")
+        w, h = im.size
+        scale = height / float(h)
+        nw = max(1, int(w * scale))
+        resized.append(im.resize((nw, height)))
+    total_w = sum(im.size[0] for im in resized) + pad * (len(resized) + 1)
+    out = Image.new("RGB", (total_w, height + 2 * pad), color=bg)
+    x = pad
+    for im in resized:
+        out.paste(im, (x, pad))
+        x += im.size[0] + pad
+    return out
+
+
+def gradio_recommend(files: List[str], occasion: str, weather: str, num_outfits: int):
+    # Return stitched outfit images and a JSON with details
     if not files:
-        return []
+        return [], {"error": "No files uploaded"}
     images = _load_images_from_files(files)
     if not images:
-        return []
-
+        return [], {"error": "Could not load images"}
+    # Build items that allow on-the-fly embedding in service
     items = [
-        {"id": f"item_{i}", "…
-        for i in range(len(…
+        {"id": f"item_{i}", "image": images[i], "category": None}
+        for i in range(len(images))
     ]
-
-    # …
-
-    for r in …
-    …
+    res = service.compose_outfits(items, context={"occasion": occasion, "weather": weather, "num_outfits": int(num_outfits)})
+    # Prepare stitched previews
+    strips: List[Image.Image] = []
+    for r in res:
+        idxs = []
+        for iid in r.get("item_ids", []):
+            try:
+                idxs.append(int(str(iid).split("_")[-1]))
+            except Exception:
+                continue
+        imgs = [images[i] for i in idxs if 0 <= i < len(images)]
+        strips.append(_stitch_strip(imgs))
+    return strips, {"outfits": res}
+
+
+with gr.Blocks(fill_height=True) as demo:
+    gr.Markdown("## Dressify – Outfit Recommendations\nUpload multiple item images and generate complete looks.")
+    with gr.Tab("Recommend"):
+        inp2 = gr.Files(label="Upload wardrobe images", file_types=["image"], file_count="multiple")
+        with gr.Row():
+            occasion = gr.Dropdown(choices=["casual", "business", "formal", "sport"], value="casual", label="Occasion")
+            weather = gr.Dropdown(choices=["any", "hot", "mild", "cold", "rain"], value="any", label="Weather")
+            num_outfits = gr.Slider(minimum=1, maximum=8, step=1, value=3, label="Num outfits")
+        out_gallery = gr.Gallery(label="Recommended Outfits", columns=1, height=320)
+        out_json = gr.JSON(label="Details")
+        btn2 = gr.Button("Generate Outfits", variant="primary")
+        btn2.click(fn=gradio_recommend, inputs=[inp2, occasion, weather, num_outfits], outputs=[out_gallery, out_json])
+    with gr.Tab("Embed (debug)"):
         inp = gr.Files(label="Upload Items (multiple images)")
         out = gr.Textbox(label="Embeddings (JSON)")
         btn = gr.Button("Compute Embeddings")
         btn.click(fn=gradio_embed, inputs=inp, outputs=out)
-    with gr.Tab("…
-    …
+    with gr.Tab("Train"):
+        gr.Markdown("Train models on Stylique/Polyvore (70/10/10 split). This runs on the Space hardware.")
+        epochs_res = gr.Slider(1, 50, value=10, step=1, label="ResNet epochs")
+        epochs_vit = gr.Slider(1, 100, value=20, step=1, label="ViT epochs")
+        train_log = gr.Textbox(label="Training Log", lines=10)
+        start_btn = gr.Button("Start Training")
+
+        def start_training(res_epochs: int, vit_epochs: int):
+            def _runner():
+                try:
+                    import subprocess
+                    if not DATASET_ROOT:
+                        train_log.value = "Dataset not ready."
+                        return
+                    export_dir = os.getenv("EXPORT_DIR", "models/exports")
+                    os.makedirs(export_dir, exist_ok=True)
+                    train_log.value = "Training ResNet…\n"
+                    subprocess.run([
+                        "python", "train_resnet.py", "--data_root", DATASET_ROOT, "--epochs", str(res_epochs),
+                        "--out", os.path.join(export_dir, "resnet_item_embedder.pth")
+                    ], check=False)
+                    train_log.value += "\nTraining ViT (triplet)…\n"
+                    subprocess.run([
+                        "python", "train_vit_triplet.py", "--data_root", DATASET_ROOT, "--epochs", str(vit_epochs),
+                        "--export", os.path.join(export_dir, "vit_outfit_model.pth")
+                    ], check=False)
+                    service.reload_models()
+                    train_log.value += "\nDone. Artifacts in models/exports."
+                except Exception as e:
+                    train_log.value += f"\nError: {e}"
+            threading.Thread(target=_runner, daemon=True).start()
+            return "Started"
+
+        start_btn.click(fn=start_training, inputs=[epochs_res, epochs_vit], outputs=train_log)
     with gr.Tab("Downloads"):
         gr.Markdown("Download trained artifacts from models/exports")
         file_list = gr.JSON(label="Artifacts JSON")
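For reference, a quick sketch of what the new _stitch_strip helper produces. This is a hypothetical check, assuming app.py imports cleanly from the repo root; the solid-colour rectangles stand in for wardrobe photos:

    from PIL import Image
    from app import _stitch_strip  # assumption: run from the repo root

    imgs = [
        Image.new("RGB", (200, 256), "red"),
        Image.new("RGB", (400, 256), "green"),
        Image.new("RGB", (300, 128), "blue"),
    ]
    # Heights are normalised to 256 px (the third image doubles to 600 px wide),
    # then everything is pasted left-to-right with 6 px padding.
    strip = _stitch_strip(imgs, height=256, pad=6)
    print(strip.size)  # (200 + 400 + 600 + 4 * 6, 256 + 2 * 6) == (1224, 268)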
@@ -241,7 +308,8 @@ with gr.Blocks() as demo:
 
 
 try:
-    # Mount Gradio onto FastAPI root path
+    # Mount Gradio onto FastAPI root path (disable SSR to avoid stray port fetches)
+    demo.queue()
    app = gr.mount_gradio_app(app, demo, path="/")
 except Exception:
     # In case mounting fails in certain runners, we still want FastAPI to be available

@@ -257,7 +325,7 @@ except Exception:
 
 
 if __name__ == "__main__":
-    # Local
-    demo.launch()
+    # Local/Space run
+    demo.queue().launch(ssr_mode=False)
 
 
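demo.queue() turns on Gradio's event queue, which long-running click handlers like the new Train tab rely on, and ssr_mode=False matches the comment about avoiding stray port fetches from Gradio's server-side rendering. One hedged caveat: the _runner thread writes to train_log.value directly, which sets the component's stored value rather than streaming updates to connected clients, so the log may only reflect progress after a refresh.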
data/polyvore.py
CHANGED
@@ -7,6 +7,7 @@ from torch.utils.data import Dataset
 from PIL import Image
 
 from utils.transforms import build_train_transforms
+from pathlib import Path
 
 
 class PolyvoreTripletDataset(Dataset):

@@ -31,11 +32,21 @@ class PolyvoreTripletDataset(Dataset):
         with open(triplet_path, "r") as f:
             self.samples: List[Dict[str, Any]] = json.load(f)
 
+    def _find_image_path(self, item_id: str) -> str:
+        base = os.path.join(self.root, "images")
+        # direct common extensions
+        for ext in (".jpg", ".jpeg", ".png", ".webp"):
+            p = os.path.join(base, f"{item_id}{ext}")
+            if os.path.isfile(p):
+                return p
+        # recursive fuzzy search
+        for p in Path(base).rglob(f"*{item_id}*"):
+            if p.suffix.lower() in (".jpg", ".jpeg", ".png", ".webp"):
+                return str(p)
+        raise FileNotFoundError(f"Image for item {item_id} not found under {base}")
+
     def _load_image(self, item_id: str) -> Image.Image:
-
-        img_path = os.path.join(self.root, "images", f"{item_id}.jpg")
-        if not os.path.exists(img_path):
-            raise FileNotFoundError(img_path)
+        img_path = self._find_image_path(item_id)
         return Image.open(img_path).convert("RGB")
 
     def __len__(self) -> int:
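The recursive fallback leans on Path.rglob matching anywhere under the images tree. A small self-contained check of that behaviour (temporary directory and file names are hypothetical):

    import tempfile
    from pathlib import Path

    base = Path(tempfile.mkdtemp())
    (base / "shard_7").mkdir()
    (base / "shard_7" / "456_crop.jpg").touch()

    # rglob("*456*") walks subdirectories, so the nested file is found even
    # though it sits below base and its name merely contains the item id.
    hits = [p for p in base.rglob("*456*")
            if p.suffix.lower() in (".jpg", ".jpeg", ".png", ".webp")]
    print(hits[0].name)  # 456_crop.jpg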
@@ -74,9 +85,7 @@ class PolyvoreOutfitDataset(Dataset):
     # If metadata isn't available, we will rely on count >= 3 and let model learn; here, keep as-is.
 
     def _load_image(self, item_id: str) -> Image.Image:
-        img_path = …
-        if not os.path.exists(img_path):
-            raise FileNotFoundError(img_path)
+        img_path = PolyvoreTripletDataset._find_image_path(self, item_id)  # reuse logic
         return Image.open(img_path).convert("RGB")
 
     def __len__(self) -> int:
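The second hunk reuses the resolver across classes by calling PolyvoreTripletDataset._find_image_path(self, item_id) with a PolyvoreOutfitDataset instance. That works because accessing a method on the class yields a plain function, and the only thing it needs from self is a root attribute. A minimal illustration (names are invented for the example):

    class Resolver:
        def where(self):
            # the only requirement on `self` is a .root attribute
            return self.root

    class Other:
        root = "/data"

    print(Resolver.where(Other()))  # prints /data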
inference.py
CHANGED
@@ -89,54 +89,67 @@ class InferenceService:
 
     @torch.inference_mode()
     def compose_outfits(self, items: List[Dict[str, Any]], context: Dict[str, Any]) -> List[Dict[str, Any]]:
-        # Ensure embeddings
+        # 1) Ensure embeddings for each input item
         proc_items: List[Dict[str, Any]] = []
         for it in items:
-            …
-            if …
-                # …
+            emb = it.get("embedding")
+            if emb is None and it.get("image") is not None:
+                # Compute on-the-fly if image provided
                 emb = self.embed_images([it["image"]])[0]
-            …
-                # …
+            if emb is None:
+                # Skip if we cannot get an embedding
                 continue
-            …
+            emb_np = np.asarray(emb, dtype=np.float32)
+            proc_items.append({
+                "id": it.get("id"),
+                "embedding": emb_np,
+                "category": it.get("category")
+            })
 
         if len(proc_items) < 2:
             return []
 
-        # Candidate generation
-        rng = np.random.default_rng(42)
+        # 2) Candidate generation
+        rng = np.random.default_rng(int(context.get("seed", 42)))
         num_outfits = int(context.get("num_outfits", 3))
-        min_size, max_size = …
-        candidates: List[List[int]] = []
+        min_size, max_size = 4, 6
         ids = list(range(len(proc_items)))
-        # slot-aware sampling if categories exist
-        def has_cat(i: int, cat_prefix: str) -> bool:
-            c = (proc_items[i].get("category") or "").lower()
-            return cat_prefix in c
 
-        …
-        …
-        …
-        accs = [i for i in ids if any(k in (proc_items[i].get("category") or "").lower() for k in ["watch", "belt", "ring", "bracelet", "accessor"])]
+        # Slot-aware pools from categories (best-effort)
+        def cat_str(i: int) -> str:
+            return (proc_items[i].get("category") or "").lower()
 
-        for …
+        uppers = [i for i in ids if any(k in cat_str(i) for k in ["top", "shirt", "tshirt", "blouse", "jacket", "hoodie"])]
+        bottoms = [i for i in ids if any(k in cat_str(i) for k in ["pant", "trouser", "jean", "skirt", "short"])]
+        shoes = [i for i in ids if any(k in cat_str(i) for k in ["shoe", "sneaker", "boot", "heel"])]
+        accs = [i for i in ids if any(k in cat_str(i) for k in ["watch", "belt", "ring", "bracelet", "accessor", "bag", "hat"])]
+
+        candidates: List[List[int]] = []
+        num_samples = max(num_outfits * 12, 24)
+        for _ in range(num_samples):
             if uppers and bottoms and shoes and accs:
-                subset = [
-                    …
+                subset = [
+                    int(rng.choice(uppers)),
+                    int(rng.choice(bottoms)),
+                    int(rng.choice(shoes)),
+                    int(rng.choice(accs)),
+                ]
+                # Optional: add one more random distinct item
                 remain = list(set(ids) - set(subset))
                 if remain and rng.random() < 0.5:
-                    subset.append(rng.choice(remain)…
+                    subset.append(int(rng.choice(remain)))
             else:
-                k = rng.integers(min_size, max_size + 1)
-                subset = rng.choice(ids, size=…
+                k = int(rng.integers(min_size, max_size + 1))
+                subset = list(map(int, rng.choice(ids, size=k, replace=False).tolist()))
             candidates.append(subset)
 
-        # Score using ViT
+        # 3) Score using ViT
         def score_subset(idx_subset: List[int]) -> float:
-            embs = torch.tensor(…
+            embs = torch.tensor(
+                np.stack([proc_items[i]["embedding"] for i in idx_subset], axis=0),
+                dtype=torch.float32,
+                device=self.device,
+            )  # (N, D)
             embs = embs.unsqueeze(0)  # (1, N, D)
             s = self.vit.score_compatibility(embs).item()
             return float(s)
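Two details worth noting: the new int(...) casts keep numpy scalar types out of the JSON-serialisable result (rng.choice and rng.integers return numpy scalars), and seeding the generator from context makes candidate sampling reproducible per request. To make the scorer's shape handling concrete, a standalone sketch (the 512-dim embedding size is an assumption for illustration, not taken from the repo):

    import numpy as np
    import torch

    item_embs = [np.random.rand(512).astype(np.float32) for _ in range(4)]  # 4 items
    embs = torch.tensor(np.stack(item_embs, axis=0))  # (4, 512) stacked subset
    embs = embs.unsqueeze(0)                          # (1, 4, 512) batch for the scorer
    print(embs.shape)  # torch.Size([1, 4, 512])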
utils/data_fetch.py
CHANGED
@@ -1,148 +1,55 @@
 import os
-import shutil
 import zipfile
 from pathlib import Path
 from typing import Optional
 
-import …
-
-try:
-    from huggingface_hub import snapshot_download  # type: ignore
-except Exception:  # pragma: no cover
-    snapshot_download = None
-
-
-def _download_zip(url: str, dest_dir: str) -> str:
-    os.makedirs(dest_dir, exist_ok=True)
-    local_zip = os.path.join(dest_dir, "dataset.zip")
-    with requests.get(url, stream=True, timeout=60) as r:
-        r.raise_for_status()
-        with open(local_zip, "wb") as f:
-            for chunk in r.iter_content(chunk_size=1024 * 1024):
-                if chunk:
-                    f.write(chunk)
-    with zipfile.ZipFile(local_zip, "r") as zf:
-        zf.extractall(dest_dir)
-    os.remove(local_zip)
-    return dest_dir
-
-
-def _unzip_inner_archives(root: str) -> None:
-    """Find and extract any zip files inside root (e.g., images.zip)."""
-    for dirpath, _dirnames, filenames in os.walk(root):
-        for fn in filenames:
-            if fn.lower().endswith(".zip"):
-                zpath = os.path.join(dirpath, fn)
-                try:
-                    with zipfile.ZipFile(zpath, "r") as zf:
-                        zf.extractall(dirpath)
-                    # keep original zip to avoid repeated work? remove to save disk
-                    try:
-                        os.remove(zpath)
-                    except Exception:
-                        pass
-                except Exception as e:  # pragma: no cover
-                    print(f"Failed to unzip inner archive {zpath}: {e}")
-
-
-def _ensure_images_dir(root: str) -> None:
-    """Ensure a stable images/ path exists under root. Create a symlink if needed."""
-    images_root = os.path.join(root, "images")
-    if os.path.isdir(images_root):
+from huggingface_hub import snapshot_download  # type: ignore
+
+
+def _unzip_images_if_needed(root: str) -> None:
+    """
+    If an archive like images.zip exists in the dataset root, extract it to root/images.
+    """
+    images_dir = os.path.join(root, "images")
+    if os.path.isdir(images_dir) and any(Path(images_dir).glob("*")):
         return
-    # …
-    …
-    if candidate_dirs:
-        src = candidate_dirs[0]
-        try:
-            os.symlink(src, images_root)
-            print(f"Created images symlink: {images_root} -> {src}")
-        except Exception:
-            # fallback: create folder and leave it empty (training will fail fast if missing)
-            os.makedirs(images_root, exist_ok=True)
-    else:
-        os.makedirs(images_root, exist_ok=True)
+    # Common zip names at root or subfolders
+    candidates = [os.path.join(root, name) for name in ("images.zip", "polyvore-images.zip", "imgs.zip")]
+    # Also search recursively for any *images*.zip
+    for p in Path(root).rglob("*images*.zip"):
+        candidates.append(str(p))
+    for zpath in candidates:
+        if os.path.isfile(zpath):
+            os.makedirs(images_dir, exist_ok=True)
+            with zipfile.ZipFile(zpath, "r") as zf:
+                zf.extractall(images_dir)
+            return
 
 
 def ensure_dataset_ready() -> Optional[str]:
     """
-    …
-    3) If DATA_ZIP_URL is set, download and unzip
-    4) Try KaggleHub (best-effort)
-    Returns resolved root path or None if nothing done.
+    Self-contained dataset fetcher for the Polyvore dataset from Hugging Face.
+    - Downloads the dataset repo Stylique/Polyvore into ./data/Polyvore
+    - Unzips images.zip into ./data/Polyvore/images
+    - Returns the dataset root path
     """
-    root = os.…
-    auto_fetch = os.getenv("AUTO_FETCH_DATA", "true").lower() == "true"
+    root = os.path.abspath(os.path.join(os.getcwd(), "data", "Polyvore"))
     Path(root).mkdir(parents=True, exist_ok=True)
 
-    # …
-    …
-        _ensure_images_dir(root)
+    # If already present, ensure images are unzipped and return
+    _unzip_images_if_needed(root)
+    if os.path.isdir(os.path.join(root, "images")):
         return root
-    if not auto_fetch:
-        return None
 
-    # …
-    …
-        _ensure_images_dir(root)
-        # If splits not provided, they'll be prepared by the caller
-        return root
-    except Exception as e:  # pragma: no cover
-        print(f"HF dataset download failed: {e}")
-
-    # Try ZIP URL
-    zip_url = os.getenv("DATA_ZIP_URL")
-    if zip_url:
-        try:
-            _download_zip(zip_url, root)
-            _unzip_inner_archives(root)
-            _ensure_images_dir(root)
-        except Exception as e:  # pragma: no cover
-            print(f"ZIP download failed: {e}")
-            return None
-
-    # Try KaggleHub (no Kaggle keys required for public datasets)
-    if kagglehub is not None and KaggleDatasetAdapter is not None:
-        try:
-            # Attempt to load core file to trigger dataset download locally
-            # User can override POLYVORE_FILE_PATH to select a specific CSV/JSON
-            file_path = os.getenv("POLYVORE_FILE_PATH", "")
-            kagglehub.load_dataset(
-                KaggleDatasetAdapter.PANDAS,
-                "dnepozitek/polyvore-outfits",
-                file_path,
-            )
-            # KaggleHub stores under ~/.cache/kagglehub/datasets/<slug>/...; copy to root if needed
-            # For simplicity, assume user will run prepare script using POLYVORE_ROOT pointing to extracted images
-            _unzip_inner_archives(root)
-            _ensure_images_dir(root)
-        except Exception as e:  # pragma: no cover
-            print(f"KaggleHub download failed: {e}")
-
-    …
+    # Download the HF dataset snapshot into root
+    try:
+        snapshot_download("Stylique/Polyvore", repo_type="dataset", local_dir=root, local_dir_use_symlinks=False)
+    except Exception as e:  # pragma: no cover
+        print(f"Failed to download Stylique/Polyvore dataset: {e}")
+        return None
+
+    # Unzip images if needed
+    _unzip_images_if_needed(root)
+    return root if os.path.isdir(os.path.join(root, "images")) else None
 
 
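With the ZIP-URL and KaggleHub fallbacks gone, preparing the dataset is a single call. A minimal usage sketch, assuming network access to the Stylique/Polyvore dataset repo on Hugging Face:

    from utils.data_fetch import ensure_dataset_ready

    root = ensure_dataset_ready()
    if root is None:
        print("Dataset could not be prepared")
    else:
        print("Ready:", root)  # ./data/Polyvore with an images/ directory populated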