buzzbandit commited on
Commit
21ec64b
Β·
verified Β·
1 Parent(s): 3ea0f12

try fix again

Browse files
Files changed (1) hide show
  1. app.py +221 -0
app.py CHANGED
@@ -331,6 +331,227 @@ COUNTRY_ALIASES = {
331
  "japan": "JAP", "china": "CHI", "hawaii": "HAW"
332
  }
333
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  def normalize_country_query(q: str) -> str | None:
335
  q = (q or "").strip().lower()
336
  if not q:
 
331
  "japan": "JAP", "china": "CHI", "hawaii": "HAW"
332
  }
333
 
334
+ def normalize_country_query(q: str) -> str | None:
335
+ q = (q or "").strip().lower()
336
+ if not q:
337
+ return None
338
+ if q in COUNTRY_ALIASES:
339
+ return COUNTRY_ALIASES[q]
340
+ if len(q) == 3 and q.upper() in COUNTRY_NAMES:
341
+ return q.upper()
342
+ return None
343
+
344
+ def parse_freeform_query(text: str):
345
+ """Supports 'plushies in uk' and 'uk plushies'."""
346
+ if not text:
347
+ return "", ""
348
+ text = text.strip().lower()
349
+ m = re.match(r"(.+?)\\s+in\\s+(.+)", text, flags=re.IGNORECASE)
350
+ if m:
351
+ return m.group(1).strip(), m.group(2).strip()
352
+ parts = text.split()
353
+ if len(parts) == 2:
354
+ first, second = parts
355
+ if normalize_country_query(first):
356
+ return second, first
357
+ elif normalize_country_query(second):
358
+ return first, second
359
+ return text, ""
360
+
361
+ # ---------------- Semantic Match ----------------
362
+ def semantic_match(query, top_k=15, debug_top_n=5):
363
+ if not query:
364
+ return {"category": None, "items": []}
365
+
366
+ query = query.strip().lower()
367
+ q_emb = embedder.encode(query, convert_to_tensor=True)
368
+ sims_items = {n: float(util.cos_sim(q_emb, emb)) for n, emb in ITEM_EMBEDS.items()}
369
+ ranked_items = sorted(sims_items.items(), key=lambda x: x[1], reverse=True)
370
+ item_hits = [n for n, score in ranked_items[:top_k] if score > 0.35]
371
+
372
+ sims_cats = {c: float(util.cos_sim(q_emb, emb)) for c, emb in CATEGORY_EMBEDS.items()}
373
+ ranked_cats = sorted(sims_cats.items(), key=lambda x: x[1], reverse=True)
374
+ top_cat, cat_score = (ranked_cats[0] if ranked_cats else (None, 0.0))
375
+
376
+ related_items = []
377
+ if top_cat and cat_score > 0.35:
378
+ related_items = [n for n, t in ITEM_TO_TYPE.items() if t == top_cat]
379
+
380
+ combined = list(set(item_hits + related_items))
381
+ return {"category": top_cat if related_items else None, "items": combined}
382
+
383
+ # ---------------- Fetch YATA ----------------
384
+ def fetch_yata(force_refresh=False):
385
+ if not force_refresh and _cache["data"] and (time.time() - _cache["timestamp"] < 300):
386
+ return _cache["data"], _cache["last_update"]
387
+ try:
388
+ resp = requests.get(API_URL, timeout=10)
389
+ resp.raise_for_status()
390
+ data = resp.json()
391
+ _cache.update({
392
+ "data": data,
393
+ "timestamp": time.time(),
394
+ "last_update": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
395
+ })
396
+ print(f"βœ… Fetched YATA data at {_cache['last_update']}")
397
+ return data, _cache["last_update"]
398
+ except Exception as e:
399
+ print(f"❌ Fetch error: {e}")
400
+ return {"stocks": {}}, "Fetch failed"
401
+
402
+ # ---------------- Core logic ----------------
403
+ def query_inventory(query_text="", category="", country_name="", capacity=10, refresh=False):
404
+ data, last_update = fetch_yata(force_refresh=refresh)
405
+ stocks = data.get("stocks", {})
406
+ rows = []
407
+
408
+ parsed_item, parsed_country = parse_freeform_query(query_text)
409
+ if not country_name and parsed_country:
410
+ country_name = parsed_country
411
+ item_term = parsed_item
412
+ semantic_result = semantic_match(item_term) if item_term else {"category": None, "items": []}
413
+ semantic_items = semantic_result["items"]
414
+ semantic_category = semantic_result["category"]
415
+
416
+ user_code = normalize_country_query(country_name)
417
+
418
+ for code_raw, cdata in stocks.items():
419
+ code = code_raw.upper()
420
+ cname = COUNTRY_NAMES.get(code, code)
421
+ if country_name:
422
+ if user_code:
423
+ country_ok = (user_code == code)
424
+ else:
425
+ country_ok = country_name.lower() in cname.lower()
426
+ else:
427
+ country_ok = True
428
+
429
+ if not country_ok:
430
+ continue
431
+
432
+ for item in cdata.get("stocks", []):
433
+ iname = item.get("name", "")
434
+ itype = ITEM_TO_TYPE.get(iname, "").lower()
435
+ qty = item.get("quantity", 0)
436
+ cost = item.get("cost", 0)
437
+
438
+ if item_term:
439
+ item_ok = (
440
+ item_term.lower() in iname.lower()
441
+ or iname in semantic_items
442
+ or (semantic_category and itype == semantic_category.lower())
443
+ )
444
+ elif category:
445
+ item_ok = category.lower() == itype
446
+ else:
447
+ item_ok = True
448
+
449
+ if item_ok:
450
+ rows.append({
451
+ "Country": cname,
452
+ "Item": iname,
453
+ "Category": itype.title(),
454
+ "Quantity": qty,
455
+ "Cost": cost,
456
+ "Max Capacity Cost": cost * capacity,
457
+ })
458
+
459
+ if not rows:
460
+ return pd.DataFrame([{"Result": "No inventory found for that query."}]), f"Last update: {last_update}"
461
+
462
+ df = pd.DataFrame(rows)
463
+ df = df.sort_values(by=["Country", "Item"])
464
+ for col in ["Quantity", "Cost", "Max Capacity Cost"]:
465
+ if col in df.columns:
466
+ df[col] = df[col].apply(lambda x: f"{x:,.0f}" if isinstance(x, (int, float)) and x != "" else x)
467
+ return df, f"Last update: {last_update}"
468
+
469
+ # ---------------- Wrapper ----------------
470
+ def run_query(query_text, category, country, capacity, refresh):
471
+ return query_inventory(query_text, category, country, capacity, refresh)
472
+
473
+ def init_capacity(saved_capacity):
474
+ try:
475
+ if saved_capacity and 5 <= float(saved_capacity) <= 88:
476
+ return float(saved_capacity)
477
+ except Exception:
478
+ pass
479
+ return 10
480
+
481
+ # ---------------- Gradio UI ----------------
482
+ with gr.Blocks(title="🧳 Torn Inventory Viewer") as iface:
483
+ gr.Markdown("## 🧳 Torn Inventory Viewer")
484
+ gr.Markdown("_Search Torn YATA travel stocks with smart semantic matching._")
485
+
486
+ query_box = gr.Textbox(label="Search (semantic, e.g. 'flowers in England')")
487
+ category_drop = gr.Dropdown(label="Category (optional exact match)", choices=[""] + ALL_CATEGORIES)
488
+ country_box = gr.Textbox(label="Country (optional, e.g. UK, Cayman, Japan)")
489
+ capacity_slider = gr.Number(label="Travel Capacity", value=10, minimum=5, maximum=88, precision=0)
490
+ refresh_check = gr.Checkbox(label="Force refresh (ignore cache)", value=False)
491
+ saved_capacity_state = gr.State(value=None)
492
+
493
+ iface.load(init_capacity, inputs=[saved_capacity_state], outputs=[capacity_slider])
494
+
495
+ run_btn = gr.Button("πŸ” Search / Refresh")
496
+ result_df = gr.Dataframe(label="Results")
497
+ meta_box = gr.Textbox(label="Metadata / Last Update")
498
+
499
+ run_btn.click(run_query,
500
+ inputs=[query_box, category_drop, country_box, capacity_slider, refresh_check],
501
+ outputs=[result_df, meta_box])
502
+
503
+ gr.HTML("""
504
+ <script>
505
+ function syncCapacity() {
506
+ const saved = localStorage.getItem('travel_capacity');
507
+ const stateEl = document.querySelector('textarea[name="saved_capacity_state"]');
508
+ if (stateEl && saved) {
509
+ stateEl.value = saved;
510
+ stateEl.dispatchEvent(new Event('input', { bubbles: true }));
511
+ }
512
+ const capField = document.querySelector('input[type=number]');
513
+ if (capField) {
514
+ capField.addEventListener('input', () => {
515
+ localStorage.setItem('travel_capacity', capField.value);
516
+ });
517
+ } else {
518
+ setTimeout(syncCapacity, 400);
519
+ }
520
+ }
521
+ setTimeout(syncCapacity, 800);
522
+ </script>
523
+ """)
524
+
525
+ try:
526
+ iface.launch()
527
+ except Exception as e:
528
+ import traceback
529
+ print("❌ Failed to start:")
530
+ traceback.print_exc()if not CATEGORY_ALIASES or not CATEGORY_EMBEDS:
531
+ CATEGORY_ALIASES = auto_alias_categories(embedder, ALL_CATEGORIES, list(ITEM_TO_TYPE.keys()))
532
+ CATEGORY_EMBEDS = {
533
+ cat: sum([embedder.encode(a, convert_to_tensor=True) for a in aliases]) / len(aliases)
534
+ for cat, aliases in CATEGORY_ALIASES.items()
535
+ }
536
+ save_cached_embeddings(CATEGORY_ALIASES, CATEGORY_EMBEDS)
537
+ else:
538
+ print("βœ… Using cached dynamic category embeddings.")
539
+
540
+ # ---------------- Country mapping ----------------
541
+ COUNTRY_NAMES = {
542
+ "ARG": "Argentina", "MEX": "Mexico", "CAN": "Canada", "UNI": "United Kingdom",
543
+ "JAP": "Japan", "SOU": "South Africa", "SWI": "Switzerland", "UAE": "United Arab Emirates",
544
+ "CHI": "China", "HAW": "Hawaii", "CAY": "Cayman Islands"
545
+ }
546
+ COUNTRY_ALIASES = {
547
+ "uk": "UNI", "england": "UNI", "united kingdom": "UNI",
548
+ "uae": "UAE", "united arab emirates": "UAE",
549
+ "south africa": "SOU", "switzerland": "SWI",
550
+ "cayman": "CAY", "cayman islands": "CAY",
551
+ "argentina": "ARG", "mexico": "MEX", "canada": "CAN",
552
+ "japan": "JAP", "china": "CHI", "hawaii": "HAW"
553
+ }
554
+
555
  def normalize_country_query(q: str) -> str | None:
556
  q = (q or "").strip().lower()
557
  if not q: