Spaces:
Sleeping
Sleeping
# app.py — GIftyPlus (lean)
# ----------------------------------------------------------------------------- | |
# High-level overview | |
# ----------------------------------------------------------------------------- | |
# GIftyPlus is a lightweight gift recommender + DIY generator. | |
# Pipeline: | |
# 1) Load & normalize an Amazon-like product dataset (name/desc/tags/price/img). | |
# 2) Build sentence embeddings for semantic retrieval (cached to .npy). | |
# 3) Rank items with a weighted score (embeddings + optional cross-encoder + | |
# interest/occasion/price bonuses) and diversify with MMR. | |
# 4) Generate a DIY gift idea (FLAN-T5), then embed 10 candidates and append | |
# the best one as a "Generated" #4 result. | |
# 5) Generate a short personalized message (FLAN-T5) with basic validators. | |
# 6) Gradio UI: input form, input summary, top-3 + generated #4, DIY section, | |
# and personalized message section. | |
# | |
# Env vars you can override: | |
# DATASET_ID, DATASET_SPLIT, MAX_ROWS, | |
# EMBED_MODEL_ID, RERANK_MODEL_ID, | |
# DIY_MODEL_ID, MAX_INPUT_TOKENS, DIY_MAX_NEW_TOKENS. | |
# ----------------------------------------------------------------------------- | |
import os, re, json, hashlib, pathlib, random | |
from typing import Dict, List, Tuple, Optional, Any | |
import numpy as np, pandas as pd, gradio as gr, torch | |
from datasets import load_dataset | |
from sentence_transformers import SentenceTransformer | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
# App title rendered in the UI (markdown).
# NOTE(review): the character after "# " looks mis-encoded in this copy (probably an
# emoji garbled by a wrong text encoding) — confirm against the UTF-8 source.
TITLE = "# ๐ GIftyPlus - Smart Gift Recommender\n*Top-3 catalog picks + 1 DIY gift + personalized message*"
# Hugging Face dataset id / split for the product catalog (env-overridable).
DATASET_ID = os.getenv("DATASET_ID", "Danielos100/Amazon_products_clean")
DATASET_SPLIT = os.getenv("DATASET_SPLIT", "train")
# Upper bound on catalog rows; load_catalog down-samples (seed 42) above this.
MAX_ROWS = int(os.getenv("MAX_ROWS", "12000"))
# SentenceTransformer model used for both document and query embeddings.
EMBED_MODEL_ID = os.getenv("EMBED_MODEL_ID", "sentence-transformers/all-MiniLM-L12-v2")
def resolve_cache_dir():
    """Return the first writable directory for the on-disk embedding cache.

    Candidates, in order:
      1) the ``EMBED_CACHE_DIR`` environment variable (skipped when unset/empty),
      2) ``<cwd>/.gifty_cache``,
      3) ``/tmp/.gifty_cache``.
    Each candidate is created if missing and probed by writing and deleting a small
    ``.write_test`` file. A candidate that cannot be created or written is skipped
    (instead of aborting start-up); if none works, the current working directory
    is returned.
    """
    candidates = [
        os.getenv("EMBED_CACHE_DIR"),
        os.path.join(os.getcwd(), ".gifty_cache"),
        "/tmp/.gifty_cache",
    ]
    for p in candidates:
        if not p:
            continue
        probe = os.path.join(p, ".write_test")
        try:
            pathlib.Path(p).mkdir(parents=True, exist_ok=True)
            with open(probe, "w") as f:
                f.write("ok")
            pathlib.Path(probe).unlink(missing_ok=True)
        except OSError:
            # Read-only / invalid location (e.g. a path under a regular file):
            # fall through to the next candidate as the docstring promises.
            continue
        return p
    return os.getcwd()
# Resolved once at import time; EmbeddingBank writes its emb_*.npy cache files here.
EMBED_CACHE_DIR = resolve_cache_dir()
# UI vocab / options
# NOTE(review): several labels below contain "โ" where a typographic character
# (apostrophe / en dash) was most likely intended — this looks like an encoding
# artifact; confirm against the UTF-8 source. OCCASION_UI and OCCASION_CANON keys
# (and any UI code using AGE_OPTIONS keys) must stay byte-identical to each other.
INTEREST_OPTIONS = ["Sports","Travel","Cooking","Technology","Music","Art","Reading","Gardening","Fashion","Gaming","Photography","Hiking","Movies","Crafts","Pets","Wellness","Collecting","Food","Home decor","Science"]
# Occasion dropdown labels; OCCASION_CANON maps each label to the canonical key used
# by OCCASION_PRIORS and by profile_to_query.
OCCASION_UI = ["Birthday","Wedding / Engagement","Anniversary","Graduation","New baby","Housewarming","Retirement","Holidays","Valentineโs Day","Promotion / New job","Get well soon"]
OCCASION_CANON = {"Birthday":"birthday","Wedding / Engagement":"wedding","Anniversary":"anniversary","Graduation":"graduation","New baby":"new_baby","Housewarming":"housewarming","Retirement":"retirement","Holidays":"holidays","Valentineโs Day":"valentines","Promotion / New job":"promotion","Get well soon":"get_well"}
# Relationship labels; keys must match REL_TO_TOKENS below.
RECIPIENT_RELATIONSHIPS = ["Family - Parent","Family - Sibling","Family - Child","Family - Other relative","Friend","Colleague","Boss","Romantic partner","Teacher / Mentor","Neighbor","Client / Business partner"]
MESSAGE_TONES = ["Formal","Casual","Funny","Heartfelt","Inspirational","Playful","Romantic","Appreciative","Encouraging"]
# Age label -> age bucket; the values match the buckets _mask_by_age checks
# ("kids", "teens", "adult", "senior"); "any" disables the age filter there.
AGE_OPTIONS = {"any":"any","kid (3โ12)":"kids","teen (13โ17)":"teens","adult (18โ64)":"adult","senior (65+)":"senior"}
GENDER_OPTIONS = ["any","female","male","nonbinary"]
# Light synonym expansion for interests; used to enrich queries and "hit" checks.
# Keys are lowercased interest names. NOTE(review): multi-word entries ("board game")
# and the upper-case "STEM" can never match the lowercase single-word token sets
# used by _interest_bonus; they only influence the embedding query.
SYNONYMS = {"sports":["fitness","outdoor","training","yoga","run"],"travel":["luggage","passport","map","trip","vacation"],"cooking":["kitchen","cookware","chef","baking"],"technology":["electronics","gadgets","device","smart","computer"],"music":["audio","headphones","earbuds","speaker","vinyl"],"art":["painting","drawing","sketch","canvas"],"reading":["book","novel","literature"],"gardening":["plants","planter","seeds","garden","indoor"],"fashion":["style","accessory","jewelry"],"gaming":["board game","puzzle","video game","controller"],"photography":["camera","lens","tripod","film"],"hiking":["outdoor","camping","backpack","trek"],"movies":["film","cinema","blu-ray","poster"],"crafts":["diy","handmade","kit","knitting"],"pets":["dog","cat","pet"],"wellness":["relaxation","spa","aromatherapy","self-care"],"collecting":["display","collector","limited edition"],"food":["gourmet","snack","treats","chocolate"],"home decor":["home","decor","wall art","candle"],"science":["lab","experiment","STEM","microscope"]}
# Relationship label -> extra query tokens that profile_to_query joins into the query.
REL_TO_TOKENS = {"Family - Parent":["parent","family"],"Family - Sibling":["sibling","family"],"Family - Child":["kids","play","family"],"Family - Other relative":["family","relative"],"Friend":["friendly"],"Colleague":["office","work","professional"],"Boss":["executive","professional","premium"],"Romantic partner":["romantic","couple"],"Teacher / Mentor":["teacher","mentor","thank_you"],"Neighbor":["neighbor","housewarming"],"Client / Business partner":["professional","thank_you","premium"]}
# --- Price parsing helpers (robust to currency symbols and ranges) --- | |
_CURRENCY_RE = re.compile(r"[^\d.,\-]+"); _NUM_RE = re.compile(r"(\d+(?:[.,]\d+)?)"); _RANGE_SEP = re.compile(r"\s*(?:-|โ|โ|to)\s*") | |
def _to_price_usd(x): | |
if pd.isna(x): return np.nan | |
s = str(x).strip().lower() | |
if _RANGE_SEP.search(s): s = _RANGE_SEP.split(s)[0] | |
s = _CURRENCY_RE.sub(" ", s); m = _NUM_RE.search(s.replace(",", ".")) | |
return float(m.group(1)) if m else np.nan | |
def _first_present(df, cands): | |
# Return the first column name that exists in df out of candidates (case-insensitive) | |
lower = {c.lower(): c for c in df.columns} | |
for c in cands: | |
if c in df.columns: return c | |
if c.lower() in lower: return lower[c.lower()] | |
return None | |
def _auto_price_col(df): | |
# Heuristics for price column detection when column name is unknown | |
for c in df.columns: | |
s = df[c] | |
if pd.api.types.is_numeric_dtype(s) and not s.dropna().empty and (s.dropna().between(0.5, 10000)).mean() > .6: return c | |
for c in df.columns: | |
if df[c].astype(str).head(200).str.lower().str.contains(r"\$|โช|eur|usd|ยฃ|โฌ|\d").mean() > .5: return c | |
return None | |
def map_amazon_to_schema(raw: pd.DataFrame) -> pd.DataFrame:
    """Map arbitrary Amazon-like columns onto the compact retrieval schema.

    Output columns:
      name        -- stripped text, at most 160 chars
      short_desc  -- stripped text, at most 600 chars
      tags        -- lowercased text with '|' replaced by ', '
      price_usd   -- float parsed by _to_price_usd (NaN when unparseable)
      image_url   -- passed through unchanged
    Source columns are located by name with _first_present; the price column falls
    back to _auto_price_col. A missing source column yields "" (NaN for price).
    """
    name_c = _first_present(raw, ["product name", "title", "name", "product_title"])
    desc_c = _first_present(raw, ["description", "product_description", "feature", "about"])
    cat_c = _first_present(raw, ["category", "categories", "main_cat", "product_category"])
    price_c = _first_present(raw, ["selling price", "price", "current_price", "list_price",
                                   "price_amount", "actual_price", "price_usd"]) or _auto_price_col(raw)
    img_c = _first_present(raw, ["image", "image_url", "imageurl", "imUrl", "img", "img_url"])

    def column(c, default):
        # A missing column becomes a constant Series on raw's index; a bare scalar
        # would make pd.DataFrame fail when every column is missing.
        return raw[c] if c is not None else pd.Series(default, index=raw.index)

    def text(series):
        # fillna first: astype(str) alone turns missing cells into the literal "nan",
        # which would then leak into docs, embeddings and the UI.
        return series.fillna("").astype(str)

    df = pd.DataFrame({
        "name": column(name_c, ""),
        "short_desc": column(desc_c, ""),
        "tags": column(cat_c, ""),
        "price_usd": column(price_c, np.nan),
        "image_url": column(img_c, ""),
    })
    # Light normalization / truncation to keep the UI compact
    df["price_usd"] = df["price_usd"].map(_to_price_usd)
    df["name"] = text(df["name"]).str.strip().str.slice(0, 160)
    df["short_desc"] = text(df["short_desc"]).str.strip().str.slice(0, 600)
    df["tags"] = text(df["tags"]).str.replace("|", ", ", regex=False).str.lower()
    return df
def extract_top_cat(tags: str) -> str:
    """Return the leading (top-level) category of a tag string, lowercased.

    The first of "|" or ">" that occurs decides the split (text before it, stripped);
    otherwise the text before the first comma is used. Falsy input gives "".
    """
    text = (tags or "").lower()
    for delimiter in ("|", ">"):
        head, found, _rest = text.partition(delimiter)
        if found:
            return head.strip()
    if not text:
        return ""
    return text.strip().split(",")[0]
def load_catalog() -> pd.DataFrame:
    """Load the product dataset and return the cleaned, feature-enriched catalog.

    Pipeline: load DATASET_ID / DATASET_SPLIT -> map to the compact schema ->
    drop duplicate (name, short_desc) rows -> keep prices in (0, 500] ->
    down-sample to MAX_ROWS (random_state=42) when larger -> add derived columns:
    ``doc`` ("name | tags | desc", the embedding input), ``top_cat`` and ``blob``
    (lowercased "name tags desc" used by the keyword filters).
    """
    raw = load_dataset(DATASET_ID, split=DATASET_SPLIT).to_pandas()
    catalog = map_amazon_to_schema(raw).drop_duplicates(subset=["name", "short_desc"])
    catalog = catalog[pd.notna(catalog["price_usd"])]
    in_budget_range = (catalog["price_usd"] > 0) & (catalog["price_usd"] <= 500)
    catalog = catalog[in_budget_range].reset_index(drop=True)
    if len(catalog) > MAX_ROWS:
        catalog = catalog.sample(n=MAX_ROWS, random_state=42).reset_index(drop=True)
    name = catalog["name"].fillna("")
    tags = catalog["tags"].fillna("")
    desc = catalog["short_desc"].fillna("")
    catalog["doc"] = (name + " | " + tags + " | " + desc).str.strip()
    catalog["top_cat"] = catalog["tags"].map(extract_top_cat)
    catalog["blob"] = (name + " " + tags + " " + desc).str.lower()
    return catalog


# Module-level catalog shared by every retrieval helper (loaded once at import).
CATALOG = load_catalog()
# ----------------------------------------------------------------------------- | |
# Embedding bank with on-disk caching | |
# ----------------------------------------------------------------------------- | |
class EmbeddingBank:
    """Sentence embeddings for a list of documents, cached on disk as a .npy file.

    Attributes:
        model_id: SentenceTransformer model id used for documents and queries.
        dataset_tag: free-form tag (the dataset id) folded into the cache key.
        model: the loaded SentenceTransformer.
        embs: array with one row per document, memory-mapped read-only from the
            cache file; rows are L2-normalized (normalize_embeddings=True).
    """

    def __init__(self, docs, model_id, dataset_tag):
        self.model_id = model_id
        self.dataset_tag = dataset_tag
        self.model = SentenceTransformer(model_id)
        self.embs = self._load_or_build(docs)

    def _cache_path(self, n, content_tag=""):
        """Return the cache file path for ``n`` docs, optionally keyed on a content digest."""
        key = f"{self.dataset_tag}|{self.model_id}|{n}"
        if content_tag:
            key += f"|{content_tag}"
        return os.path.join(EMBED_CACHE_DIR, f"emb_{hashlib.md5(key.encode()).hexdigest()[:10]}.npy")

    def _load_or_build(self, docs):
        """Load cached embeddings for ``docs`` or encode them and write the cache."""
        # Key on the documents' content, not only their count: with MAX_ROWS capping
        # the catalog, another split / dataset revision / doc format often has the
        # same length and would otherwise silently reuse misaligned embeddings.
        # (Caches written under the old count-only key are simply rebuilt once.)
        content_tag = hashlib.md5("\x1f".join(map(str, docs)).encode("utf-8")).hexdigest()
        path = self._cache_path(len(docs), content_tag)
        if os.path.exists(path):
            try:
                embs = np.load(path, mmap_mode="r")
            except (OSError, ValueError, EOFError):
                embs = None  # corrupt / truncated cache file: rebuild below
            if embs is not None and embs.shape[0] == len(docs):
                return embs
        embs = self.model.encode(docs, convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=True)
        # Write to a temp file and rename atomically so an interrupted write cannot
        # leave a truncated cache behind. np.save on a file object adds no suffix.
        tmp = f"{path}.{os.getpid()}.tmp"
        with open(tmp, "wb") as f:
            np.save(f, embs)
        os.replace(tmp, path)
        return np.load(path, mmap_mode="r")

    def query_vec(self, text):
        """Embed a single query string (L2-normalized) and return its vector."""
        return self.model.encode([text], convert_to_numpy=True, normalize_embeddings=True)[0]
# Shared embedding index over the catalog's "doc" strings (disk-cached by EmbeddingBank).
EMB = EmbeddingBank(CATALOG["doc"].tolist(), EMBED_MODEL_ID, DATASET_ID)

# Lowercase word tokens (a letter/digit followed by letters, digits, '-' or "'"),
# collected per item for cheap lexical overlap checks (used by _interest_bonus).
_tok_rx = re.compile(r"[a-z0-9][a-z0-9\-']*")


def _token_set(text):
    """Return the set of lowercase word tokens found in ``text``."""
    return set(_tok_rx.findall(str(text).lower()))


if "tok_set" not in CATALOG.columns:
    _item_text = CATALOG["name"].fillna("") + " " + CATALOG["tags"].fillna("") + " " + CATALOG["short_desc"].fillna("")
    CATALOG["tok_set"] = _item_text.map(_token_set)
# Optional cross-encoder for re-ranking (small CPU-friendly model by default) | |
try: | |
from sentence_transformers import CrossEncoder | |
except: | |
CrossEncoder=None | |
RERANK_MODEL_ID=os.getenv("RERANK_MODEL_ID","cross-encoder/ms-marco-MiniLM-L-6-v2") | |
_CE_MODEL=None | |
def _load_cross_encoder(): | |
global _CE_MODEL | |
if _CE_MODEL is None and CrossEncoder is not None: | |
_CE_MODEL=CrossEncoder(RERANK_MODEL_ID, device="cpu") | |
return _CE_MODEL | |
# Occasion-specific keyword priors (light bonus shaping).
# Keyed by the canonical occasion ids produced by OCCASION_CANON. Each value is a list
# of (keyword, weight) pairs; _occasion_bonus adds the weight of every keyword that
# occurs as a plain substring of an item's lowercased "blob" (so e.g. "fun" also
# matches "function") and caps the total at 0.15.
OCCASION_PRIORS={"valentines":[("jewelry",.12),("chocolate",.10),("candle",.08),("romantic",.08),("couple",.08),("heart",.06)],
"birthday":[("fun",.06),("game",.06),("personalized",.06),("gift set",.05),("surprise",.04)],
"anniversary":[("couple",.10),("jewelry",.10),("photo",.08),("frame",.06),("memory",.06),("candle",.06)],
"graduation":[("journal",.10),("planner",.08),("office",.08),("coffee",.06),("motivation",.06)],
"housewarming":[("home",.10),("kitchen",.08),("decor",.10),("candle",.06),("serving",.06)],
"new_baby":[("baby",.12),("nursery",.10),("soft",.06),("blanket",.06)],
"retirement":[("relax",.08),("hobby",.08),("travel",.06),("book",.06)],
"holidays":[("holiday",.10),("winter",.08),("chocolate",.08),("cozy",.06),("family",.06)],
"promotion":[("desk",.10),("office",.10),("premium",.08),("organizer",.06)],
"get_well":[("cozy",.10),("tea",.08),("soothing",.06),("care",.06)]}
def expand_with_synonyms(tokens: List[str]) -> List[str]:
    """Return every non-blank token (stripped, lowercased) followed by its SYNONYMS.

    Order is preserved and duplicates are kept, e.g. ["Pets"] -> ["pets", "dog", "cat", "pet"].
    """
    cleaned = (raw.strip().lower() for raw in tokens)
    return [word for tok in cleaned if tok for word in (tok, *SYNONYMS.get(tok, []))]
def profile_to_query(p: Dict) -> str:
    """Build the dense-retrieval query string for a recipient profile.

    Segments (non-empty ones joined by " | "): the interests plus synonyms repeated
    three times (to up-weight them), the relationship's REL_TO_TOKENS, and the
    canonical occasion; followed by a natural-language tail
    "gift ideas for a <relationship> for <occasion>; likes <interests or 'general'>".
    """
    interests = [item.lower() for item in p.get("interests", []) if item]
    weighted_terms = expand_with_synonyms(interests) * 3
    relationship = p.get("relationship", "Friend")
    occasion = OCCASION_CANON.get(p.get("occ_ui", "Birthday"), "birthday")
    segments = [
        ", ".join(weighted_terms),
        ", ".join(REL_TO_TOKENS.get(relationship, [])),
        occasion,
    ]
    liked = ", ".join(interests) or "general"
    tail = f"gift ideas for a {relationship} for {occasion}; likes {liked}"
    return " | ".join(seg for seg in segments if seg) + " | " + tail
def _gender_ok_mask(g:str)->np.ndarray: | |
# Gender-aware filter: exclude items explicitly labeled for the opposite gender unless unisex | |
g=(g or "any").lower(); bl=CATALOG["blob"] | |
has_m=bl.str.contains(r"\b(men|man's|mens|male|for men)\b",regex=True,na=False) | |
has_f=bl.str.contains(r"\b(women|woman's|womens|female|for women|dress)\b",regex=True,na=False) | |
has_u=bl.str.contains(r"\bunisex|gender neutral\b",regex=True,na=False) | |
if g=="female": return (~has_m | has_u).to_numpy() | |
if g=="male": return (~has_f | has_u).to_numpy() | |
return np.ones(len(bl),bool) | |
def _mask_by_age(age:str, blob:pd.Series)->np.ndarray: | |
# Age-aware filter: crude regex to separate kids/teens/adults | |
kids=blob.str.contains(r"\b(?:kid|kids|child|children|toddler|baby|boys?|girls?|kid's|children's)\b",regex=True,na=False) | |
teen=blob.str.contains(r"\b(?:teen|teens|young adult|ya)\b",regex=True,na=False) | |
if age in ("adult","senior"): return (~kids).to_numpy() | |
if age=="teens": return ((~kids)|teen).to_numpy() | |
if age=="kids": return (kids | (~teen & kids)).to_numpy() | |
return np.ones(len(blob),bool) | |
def _interest_bonus(p: Dict, idx: np.ndarray) -> np.ndarray:
    """Lexical interest bonus per candidate: 0.10 per overlapping token, capped at 0.60.

    The vocabulary is the lowercased interests plus their SYNONYMS; it is intersected
    with each candidate's precomputed CATALOG "tok_set". Returns float32 zeros when
    there is no vocabulary or no candidate.
    """
    interests = [name.lower() for name in p.get("interests", []) if name]
    vocab = set(interests)
    for name in interests:
        vocab.update(SYNONYMS.get(name, []))
    if idx.size == 0 or not vocab:
        return np.zeros(len(idx), "float32")
    token_sets = CATALOG["tok_set"]
    overlap = np.fromiter((len(token_sets.iat[i] & vocab) for i in idx), dtype="float32", count=len(idx))
    return .10 * np.clip(overlap, 0, 6)
def _occasion_bonus(idx: np.ndarray, occ_ui: str) -> np.ndarray:
    """Occasion keyword bonus per candidate (float32), capped at 0.15.

    Sums the OCCASION_PRIORS weights of every keyword that is a substring of the
    item's lowercased "blob". An empty/None occasion label falls back to "Birthday".
    """
    occasion = OCCASION_CANON.get(occ_ui or "Birthday", "birthday")
    priors = OCCASION_PRIORS.get(occasion, [])
    if idx.size == 0 or not priors:
        return np.zeros(len(idx), "float32")
    texts = CATALOG["blob"].to_numpy()
    capped = [min(sum(weight for keyword, weight in priors if keyword in texts[i]), .15) for i in idx]
    return np.asarray(capped, dtype="float32")
def _minmax(x:np.ndarray)->np.ndarray: | |
# Normalize to [0,1] with safe guard for constant vectors | |
if x.size==0: return x | |
lo,hi=float(np.min(x)),float(np.max(x)); | |
return np.zeros_like(x) if hi<=lo+1e-9 else (x-lo)/(hi-lo) | |
def _mmr_select(cand_idx:np.ndarray, scores:np.ndarray, k:int, lambda_:float=.7)->np.ndarray: | |
# MMR selection to maintain diversity in the final top-k | |
if cand_idx.size<=k: return cand_idx[np.argsort(-scores)][:k] | |
picked=[]; rest=list(range(len(cand_idx))); rel=_minmax(scores) | |
V=np.asarray(EMB.embs,"float32")[cand_idx]; V/=np.linalg.norm(V,axis=1,keepdims=True)+1e-8 | |
while len(picked)<k and rest: | |
if not picked: picked.append(rest.pop(int(np.argmax(rel[rest])))); continue | |
sim_to_sel=np.array([float((V[c]@V[picked].T) if np.ndim(V[c]@V[picked].T)==0 else np.max(V[c]@V[picked].T)) for c in rest],"float32") | |
j=int(np.argmax(lambda_*rel[rest]-(1-lambda_)*sim_to_sel)); picked.append(rest.pop(j)) | |
return cand_idx[np.array(picked,int)] | |
def recommend_top3_budget_first(
    p: Dict,
    include_synth: bool = True,
    synth_n: int = 10,
    widen_budget_frac: float = 0.5
) -> pd.DataFrame:
    """
    Retrieve -> score -> diversify catalog items for a recipient profile.

    Always returns semantically-ranked catalog results (no "cheapest-3" fallback).
    If the strict filters leave no candidates, they are relaxed step by step
    (``mask_chain``), but the surviving pool is still ranked by embeddings + bonuses.
    Optionally appends a 4th 'Generated' DIY item.

    Args:
        p: profile dict. Read here: budget_min (default 0), budget_max (default 1e9),
            age_range, gender, occ_ui; profile_to_query / _interest_bonus also read
            interests and relationship.
        include_synth: append the best of ``synth_n`` generated DIY candidates.
        synth_n: number of DIY candidates to generate (at least 1).
        widen_budget_frac: relative widening of the budget for the relaxed pass
            (0.5 -> [0.5 * min, 1.5 * max]).

    Returns:
        DataFrame with columns name, short_desc, price_usd, image_url, similarity:
        up to 3 catalog rows (similarity = final blended score) plus, optionally,
        the generated row (similarity = its embedding similarity to the query).
    """
    out_cols = ["name", "short_desc", "price_usd", "image_url", "similarity"]

    # ---------- Filters (progressive relaxation) ----------
    lo, hi = float(p.get("budget_min", 0)), float(p.get("budget_max", 1e9))
    if lo > hi:
        # A swapped budget would empty the strict window and silently fall through
        # to the budget-free passes; treat it as the intended range instead.
        lo, hi = hi, lo
    blob = CATALOG["blob"]
    price = CATALOG["price_usd"].values
    age_ok = _mask_by_age(p.get("age_range", "any"), blob)
    gen_ok = _gender_ok_mask(p.get("gender", "any"))
    price_ok_strict = (price >= lo) & (price <= hi)
    price_ok_wide = (price >= max(0, lo * (1 - widen_budget_frac))) & \
                    (price <= (hi * (1 + widen_budget_frac) if hi < 1e8 else hi))
    mask_chain = [
        price_ok_strict & age_ok & gen_ok,  # strictest: budget + age + gender
        price_ok_strict & gen_ok,           # drop the age filter
        price_ok_wide & gen_ok,             # widen the budget window
        age_ok & gen_ok,                    # ignore the budget
        gen_ok,                             # gender only
        np.ones(len(CATALOG), bool),        # no filters at all
    ]
    idx = np.array([], dtype=int)
    for m in mask_chain:
        cand = np.where(m)[0]
        if cand.size:
            idx = cand
            break

    # ---------- Query & base similarities ----------
    q = profile_to_query(p)
    qv = EMB.query_vec(q).astype("float32")
    embs = np.asarray(EMB.embs, "float32")
    emb_sims = embs[idx] @ qv

    # ---------- Bonuses (computed only for the surviving candidate indices) ----------
    target = (lo + hi) / 2.0 if hi > lo else hi
    prices = price[idx]
    # Small bonus for being close to the budget mid-point
    price_bonus = np.clip(.12 - np.abs(prices - target) / max(target, 1.0), 0, .12).astype("float32")
    int_bonus = _interest_bonus(p, idx)
    occ_bonus = _occasion_bonus(idx, p.get("occ_ui", "Birthday"))
    # Pre-score, guarding against NaN/Inf before ranking
    pre = np.nan_to_num(emb_sims + price_bonus + int_bonus + occ_bonus, nan=0.0, posinf=0.0, neginf=0.0)

    # ---------- Local candidate pool (cost/quality trade-off) ----------
    K1 = min(48, idx.size)
    top_local = np.argpartition(-pre, K1 - 1)[:K1] if K1 > 0 else np.array([], dtype=int)
    cand_idx = idx[top_local]

    # ---------- Feature normalization ----------
    emb_n = _minmax(np.nan_to_num(emb_sims[top_local], nan=0.0))
    price_n = _minmax(np.nan_to_num(price_bonus[top_local], nan=0.0))
    int_n = _minmax(np.nan_to_num(int_bonus[top_local], nan=0.0))
    occ_n = _minmax(np.nan_to_num(occ_bonus[top_local], nan=0.0))

    # ---------- Optional cross-encoder (re-scores the top 24 by embedding) ----------
    ce_n = np.zeros_like(emb_n)
    ce = _load_cross_encoder() if cand_idx.size else None  # nothing to re-rank otherwise
    if ce is not None:
        docs = CATALOG.loc[cand_idx, "doc"].tolist()
        k_ce = min(24, len(docs))
        tl = np.argpartition(-emb_n, k_ce - 1)[:k_ce]
        ce_raw = np.array(ce.predict([(q, docs[i]) for i in tl]), "float32")
        ce_n[tl] = _minmax(ce_raw)

    # ---------- Final score (weights tuned manually) ----------
    final = np.nan_to_num(.56*emb_n + .26*ce_n + .10*int_n + .05*occ_n + .03*price_n, nan=0.0)

    # ---------- Select top-3 with diversity (MMR) ----------
    k = int(min(3, cand_idx.size))
    pick = _mmr_select(cand_idx, final, k=k) if k > 0 else np.array([], dtype=int)
    if pick.size == 0:
        pick = cand_idx[np.argsort(-final)[:k]]

    # ---------- Build result ----------
    res = CATALOG.loc[pick].copy()
    pos = {int(c): i for i, c in enumerate(cand_idx)}
    res["similarity"] = [float(final[pos[int(i)]]) if int(i) in pos else np.nan for i in pick]

    # ---------- Optional synthetic #4 ----------
    if include_synth:
        try:
            synth = pick_best_synthetic(p, qv, generate_synthetic_candidates(p, n=int(max(1, synth_n))))
            if synth is not None:
                res = pd.concat([res, pd.DataFrame([synth])[out_cols]], ignore_index=True)
        except Exception:
            # Best effort: keep the UI working with the catalog picks even when DIY
            # generation fails, but log the failure instead of swallowing it.
            import logging
            logging.getLogger(__name__).warning(
                "Generated DIY #4 item failed; returning catalog picks only", exc_info=True
            )
    return res[out_cols].reset_index(drop=True)
# ===== DIY (FLAN-only) ===== | |
# Seq2seq checkpoint used for every DIY prompt (env-overridable); always run on CPU.
DIY_MODEL_ID = os.getenv("DIY_MODEL_ID", "google/flan-t5-small")
DIY_DEVICE = torch.device("cpu")
# Prompt truncation length used by _gen (env-overridable).
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "384"))
# NOTE(review): not referenced by the DIY helpers visible here — they pass explicit
# max_new_tokens values; confirm whether other code uses it.
DIY_MAX_NEW_TOKENS = int(os.getenv("DIY_MAX_NEW_TOKENS", "120"))

# Light aliases to seed the DIY gift title with an interest token
INTEREST_ALIASES = {
    "Reading": ["book", "novel", "literary"],
    "Fashion": ["style", "chic", "silk"],
    "Home decor": ["candle", "wall", "jar"],
    "Technology": ["tech", "gadget", "usb"],
    "Movies": ["film", "cinema", "poster"],
}
# Generic nouns used to pad names that end up too short.
FALLBACK_NOUNS = ["Kit", "Set", "Bundle", "Box", "Pack"]
# model id -> (tokenizer, model); filled lazily by _load_flan.
_diy_cache_model = {}
def _load_flan(mid: str):
    """Return a cached ``(tokenizer, model)`` pair for the seq2seq checkpoint ``mid``.

    The first call per id loads the fast tokenizer and the safetensors weights,
    moves the model to DIY_DEVICE in eval mode and stores the pair in
    ``_diy_cache_model``; later calls return the cached pair.
    NOTE(review): trust_remote_code=True lets the Hub repo run its own Python code;
    since DIY_MODEL_ID is env-configurable, only point it at trusted repositories.
    """
    cached = _diy_cache_model.get(mid)
    if cached is not None:
        return cached
    tokenizer = AutoTokenizer.from_pretrained(mid, use_fast=True, trust_remote_code=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(mid, trust_remote_code=True, use_safetensors=True)
    model = model.to(DIY_DEVICE).eval()
    _diy_cache_model[mid] = (tokenizer, model)
    return _diy_cache_model[mid]
def _gen(tok, mdl, prompt, max_new_tokens=64, do_sample=False, temperature=.9, top_p=.95, seed=None):
    """Run one generation with ``mdl`` and return the decoded, stripped text.

    Seeds Python's ``random`` and torch with ``seed`` (a fresh random seed when
    None) before generating. The prompt is truncated to MAX_INPUT_TOKENS. With
    ``do_sample`` the call samples using ``temperature`` / ``top_p``; otherwise it
    decodes greedily with a single beam. The EOS id is also passed as the pad id.
    """
    if seed is None:
        seed = random.randint(1, 10_000_000)
    random.seed(seed)
    torch.manual_seed(seed)
    encoded = tok(prompt, truncation=True, max_length=MAX_INPUT_TOKENS, return_tensors="pt")
    inputs = {name: tensor.to(DIY_DEVICE) for name, tensor in encoded.items()}
    if do_sample:
        decoding = {"do_sample": True, "temperature": temperature, "top_p": top_p}
    else:
        decoding = {"do_sample": False, "num_beams": 1}
    output_ids = mdl.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        eos_token_id=tok.eos_token_id,
        pad_token_id=tok.eos_token_id,
        **decoding,
    )
    return tok.decode(output_ids[0], skip_special_tokens=True).strip()
def _choose_interest_token(interests):
    """Pick a lowercase token representing the recipient's interests.

    Returns a random alias of the first interest that has entries in
    INTEREST_ALIASES; otherwise the first word of the first interest (lowercased),
    or "gift" when there are no interests.
    """
    aliased = next((it for it in interests if INTEREST_ALIASES.get(it)), None)
    if aliased is not None:
        return random.choice(INTEREST_ALIASES[aliased])
    if not interests:
        return "gift"
    return interests[0].split()[0].lower()
def _title_case(s): s=re.sub(r'\s+',' ',s).strip(); s=re.sub(r'["โโโโ]+','',s); return " ".join([w.capitalize() for w in s.split()]) | |
def _sanitize_name(name, interests):
    """Clean an LLM-proposed gift name into a short, interest-infused Title Case title.

    Steps:
      1. strip boilerplate ("the name"/"name"/"title" prefixes, "member of the family");
      2. drop trailing ':' / '-' / en or em dashes;
      3. ensure the token from _choose_interest_token appears: prepended for names
         with fewer than 4 tokens (a FALLBACK_NOUNS word when the name is empty),
         otherwise inserted after the first word;
      4. collapse repeated "Home Decor", title-case, cut to 80 chars;
      5. names shorter than 3 words become "<Alias> <FallbackNoun>".
    """
    for pattern in (r"^the name\b", r"\bmember of the family\b", r"^name\b", r"^title\b"):
        name = re.sub(pattern, "", name, flags=re.I).strip()
    # Trailing colon / hyphen / en dash / em dash (the dash characters were mis-encoded).
    name = re.sub(r'[:\-–—]+$', "", name).strip()
    alias = _choose_interest_token(interests)
    if alias not in name.lower():
        tokens = [t for t in re.split(r"[\s\-]+", name) if t]
        if len(tokens) < 4:
            body = " ".join(t.capitalize() for t in tokens) if tokens else random.choice(FALLBACK_NOUNS)
            name = f"{alias.capitalize()} {body}"
        else:
            name = " ".join([tokens[0], alias.capitalize(), *tokens[1:]])
    name = re.sub(r'\b(Home Decor:?\s*){2,}', 'Home Decor ', name, flags=re.I)
    name = _title_case(name)[:80]
    if len(name.split()) < 3:
        name = f"{alias.capitalize()} {random.choice(FALLBACK_NOUNS)}"
    return name
def _split_list_text(s,seps): | |
# Parse list-like text returned by LLM into clean items (fallback across separators) | |
s=s.strip() | |
for sep in seps: | |
if sep in s: | |
parts=[p.strip(" -โข*.,;:") for p in s.split(sep) if p.strip(" -โข*.,;:")] | |
if len(parts)>=2: return parts | |
return [p.strip(" -โข*.,;:") for p in re.split(r"[\n\r;]+", s) if p.strip(" -โข*.,;:")] | |
def _coerce_materials(items): | |
# Normalize materials list: dedupe, keep short, ensure quantities, pad with basics | |
out=[] | |
for it in items: | |
it=re.sub(r'\s+',' ',it).strip(" -โข*.,;:"); | |
if not it: continue | |
it=re.sub(r'(\b\w+\b)(?:\s+\1){2,}',r'\1',it,flags=re.I) | |
if len(it)>60: it=it[:58]+"โฆ" | |
if not re.search(r"\d",it): it+=" x1" | |
if it.lower() not in [x.lower() for x in out]: out.append(it) | |
if len(out)>=8: break | |
base=["Small gift box x1","Decorative paper x2","Twine 2 m","Cardstock sheets x2","Double-sided tape x1","Stickers x8","Ribbon 1 m","Fine-tip marker x1"] | |
for b in base: | |
if len(out)>=6: break | |
if b.lower() not in [x.lower() for x in out]: out.append(b) | |
return out[:8] | |
def _coerce_steps(items): | |
# Normalize step list: trim, remove numbering, enforce sentence case, pad to 6+ | |
out=[] | |
for it in items: | |
it=it.strip(" -โข*.,;:"); | |
if not it: continue | |
it=re.sub(r'\s+',' ',it); | |
if len(it)>120: it=it[:118]+"โฆ" | |
it=re.sub(r'^(?:\d+[\).\s-]*)','',it); it=it[0].upper()+it[1:] if it else it; out.append(it) | |
if len(out)>=8: break | |
while len(out)<6: out.append(f"Refine and decorate step {len(out)+1}") | |
return out[:8] | |
def _only_int(s): m=re.search(r"-?\d+",s); return int(m.group()) if m else None | |
def _clamp_num(v,lo,hi,default): | |
# Clamp numeric values into a valid range; fallback to default or midpoint | |
try: x=float(v); return int(min(max(x,lo),hi)) | |
except: return int((lo+hi)/2 if default is None else default) | |
def diy_generate(profile:Dict)->Tuple[dict,str]:
    """Generate one DIY gift idea for ``profile`` with the FLAN-T5 model.

    Runs six separate generations: the name (greedy, then _sanitize_name), a
    two-sentence overview (sampled), a materials list (greedy), a steps list
    (sampled), the total cost and the minutes (greedy; first integer parsed).
    Lists are normalized with _coerce_materials / _coerce_steps; the cost is clamped
    to [budget_min, budget_max] (midpoint fallback) and the minutes to 20-180
    (fallback 60).

    Args:
        profile: dict; reads recipient_name, relationship, occ_ui, interests,
            budget_min, budget_max, age_range, gender (each with a default).

    Returns:
        (idea, "ok") where idea has keys gift_name, overview, materials_needed,
        steps, estimated_cost_usd, estimated_time_minutes.

    Raises:
        ValueError / TypeError when budget_min / budget_max are not numeric.
    """
    # Generate a DIY gift object (name, overview, materials, steps, cost, time)
    tok,mdl=_load_flan(DIY_MODEL_ID)
    # Profile with defaults; budgets coerced to int. "occasion", "age_range" and
    # "gender" are carried along but not used by the prompts below.
    p={"recipient_name":profile.get("recipient_name","Recipient"),"relationship":profile.get("relationship","Friend"),
    "occ_ui":profile.get("occ_ui","Birthday"),"occasion":profile.get("occ_ui","Birthday"),"interests":profile.get("interests",[]),
    "budget_min":int(float(profile.get("budget_min",10))),"budget_max":int(float(profile.get("budget_max",100))),
    "age_range":profile.get("age_range","any"),"gender":profile.get("gender","any")}
    lang="English"; ints_str=", ".join(p["interests"]) or "general"
    # Name prompt: lists every interest plus its INTEREST_ALIASES as allowed tokens.
    # NOTE(review): "4โ8" and the "โ" in the examples look like a mis-encoded en dash
    # and arrow — confirm against the UTF-8 source.
    prompt_name=(f"Return ONLY a DIY gift NAME in Title Case (4โ8 words). Must include at least one interest token from: "
    f"{', '.join(sum(([it]+INTEREST_ALIASES.get(it,[]) for it in p['interests']), [])) or 'gift'}. "
    f"Occasion: {p['occ_ui']}. Relationship: {p['relationship']}. Language: {lang}. Forbidden: the words 'name','title','family'. "
    "No quotes, no trailing punctuation.\nExamples:\nReading โ Literary Candle Bookmark Kit\nTechnology โ Gadget Cable Organizer Set\nHome decor โ Rustic Jar Candle Bundle\nOutput:")
    name=_sanitize_name(_gen(tok,mdl,prompt_name, max_new_tokens=24, do_sample=False), p["interests"])
    overview=_gen(tok,mdl,(f"Write EXACTLY 2 sentences in {lang} for a handmade gift called '{name}'. Mention {p['recipient_name']} "
    f"({p['relationship']}) and the occasion ({p['occ_ui']}). Explain how it reflects the interests: {ints_str}. "
    "No lists, no emojis. Output only the two sentences."), max_new_tokens=80, do_sample=True, temperature=.9, top_p=.95)
    # Materials: comma-separated (falls back to ';'); steps: semicolon-separated (falls back to newlines).
    materials=_split_list_text(_gen(tok,mdl,(f"List 6 concise materials with quantities to make '{name}' cheaply. Keep total within "
    f"{p['budget_min']}-{p['budget_max']} USD. Output ONLY a comma-separated list."), max_new_tokens=96, do_sample=False), [",",";"])
    steps=_split_list_text(_gen(tok,mdl,(f"Write 6 short imperative steps to make '{name}'. Output ONLY a semicolon-separated list."), max_new_tokens=120, do_sample=True, temperature=.9, top_p=.95), [";","\n"])
    # Numeric answers: first integer in the output, or None when the model gives none.
    cost=_only_int(_gen(tok,mdl,(f"Return ONE integer total cost in USD between {p['budget_min']}-{p['budget_max']}. Output NUMBER only."), max_new_tokens=6, do_sample=False))
    minutes=_only_int(_gen(tok,mdl,"Return ONE integer minutes between 20 and 180. Output NUMBER only.", max_new_tokens=6, do_sample=False))
    idea={"gift_name":name,"overview":overview,"materials_needed":_coerce_materials(materials),"steps":_coerce_steps(steps),
    "estimated_cost_usd":_clamp_num(cost,p["budget_min"],p["budget_max"],None),"estimated_time_minutes":_clamp_num(minutes,20,180,60)}
    return idea,"ok"
def generate_synthetic_candidates(profile, n=10):
    """Generate ``n`` DIY gift candidates shaped like catalog rows.

    Each candidate comes from one diy_generate(profile) call and is a dict with:
    name (<= 160 chars), short_desc (the overview, <= 300 chars), price_usd (the
    idea's estimated cost as int, or a random int within the budget when that is
    falsy), image_url ("") and doc (lowercased "name | custom | desc", embedded
    later by pick_best_synthetic).
    """
    budget_lo = int(float(profile.get("budget_min", 10)))
    budget_hi = int(float(profile.get("budget_max", 100)))

    def to_candidate(idea):
        title = idea.get("gift_name", "Custom DIY Gift")[:160]
        summary = (idea.get("overview", "") or "").strip()[:300]
        return {
            "name": title,
            "short_desc": summary,
            "price_usd": int(idea.get("estimated_cost_usd") or random.randint(budget_lo, budget_hi)),
            "image_url": "",
            "doc": f"{title} | custom | {summary}".lower(),
        }

    return [to_candidate(diy_generate(profile)[0]) for _ in range(n)]
def pick_best_synthetic(profile, qv, candidates):
    # Return a shallow copy of the candidate whose embedded "doc" has the highest
    # dot product with qv, with that score stored under "similarity".
    # Returns None for an empty/None candidate list. `profile` is currently unused.
    if not candidates:
        return None
    texts = [cand["doc"] for cand in candidates]
    embeddings = EMB.model.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
    scores = embeddings @ qv
    best_idx = int(np.argmax(scores))
    winner = dict(candidates[best_idx])
    winner["similarity"] = float(scores[best_idx])
    return winner
# --------------------- Personalized Message (FLAN + validation) ---------------------
# Implementation ported from the Colab; tone-specific constraints + simple checks.
# Seq2seq model used for the gift-card message (loaded lazily by _msg_load).
MSG_MODEL_ID = "google/flan-t5-small"
# Device the message model is moved to in _msg_load.
MSG_DEVICE = "cpu"
# Ranges passed to random.uniform to pick temperature / top_p for each attempt.
TEMP_RANGE = (0.88, 1.10)
TOPP_RANGE = (0.90, 0.96)
# repetition_penalty passed to model.generate.
REP_PENALTY = 1.12
# max_new_tokens per generation attempt.
MSG_MAX_NEW_TOKENS = 90
# Number of sample-and-validate attempts before falling back.
MSG_MAX_TRIES = 4
# Most recently returned message; the near-duplicate validator compares against it.
_last_msg: Optional[str] = None
# Lazily-populated tokenizer/model cache (filled by _msg_load on first use).
_msg_tok, _msg_mdl = None, None
# Per-tone prompt guidance for the message LLM, keyed by the UI tone name.
# Each entry holds "system" (a one-line instruction string) and "rules" (a list of
# micro-rules rendered as "- rule" lines). Unknown tones fall back to "Heartfelt"
# in the prompt builder.
# NOTE(review): the Formal rule contains a literal "{name}" placeholder; nothing in
# this dict substitutes it, so the prompt builder must.
TONE_STYLES: Dict[str, Dict[str, Any]] = {
    "Formal": {
        "system": "Write 2โ3 refined sentences with professional courtesy and clarity.",
        "rules": [
            "You may begin with 'Dear {name},' but keep it concise.",
            "Use precise vocabulary; avoid colloquialisms.",
            "Conclude with a dignified line."
        ],
    },
    "Casual": {
        "system": "Write 2โ3 relaxed sentences with natural, friendly language.",
        "rules": [
            "Keep it light and conversational.",
            "Reference one concrete interest detail.",
            "End upbeat without clichรฉs."
        ],
    },
    "Funny": {
        "system": "Write 2โ3 witty sentences with playful humor.",
        "rules": [
            "Add one subtle pun linked to the occasion or interests.",
            "No slapstick; keep it tasteful.",
            "End with a cheeky nudge."
        ],
    },
    "Heartfelt": {
        "system": "Write 2โ3 warm, sincere sentences with genuine sentiment.",
        "rules": [
            "Open with an image or specific detail; avoid templates.",
            "Let one verb carry the energy; minimal adjectives.",
            "Close with a crisp, personal wish."
        ],
    },
    "Inspirational": {
        "system": "Write 2โ3 uplifting sentences with forward-looking energy.",
        "rules": [
            "Honor a trait or effort implied by the interests.",
            "Use a subtle metaphor; avoid grandiose platitudes.",
            "Finish with a compact, future-facing line."
        ],
    },
    "Playful": {
        "system": "Write 2โ3 lively sentences with bounce and rhythm.",
        "rules": [
            "Sneak a gentle internal rhyme or alliteration.",
            "Keep syntax varied and musical.",
            "Land on a spirited close."
        ],
    },
    "Romantic": {
        "system": "Write 2โ3 intimate sentences, warm and elegant.",
        "rules": [
            "Reference a shared moment or interest; keep it subtle.",
            "No clichรฉs or over-sweet phrasing.",
            "End with a soft, affectionate note."
        ],
    },
    "Appreciative": {
        "system": "Write 2โ3 sentences that express genuine appreciation.",
        "rules": [
            "Name a specific quality or habit tied to the interests.",
            "Avoid business thank-you clichรฉs.",
            "Close with concise gratitude."
        ],
    },
    "Encouraging": {
        "system": "Write 2โ3 supportive sentences that motivate gently.",
        "rules": [
            "Acknowledge progress or perseverance (hinted by interests).",
            "Offer one practical, hopeful sentiment.",
            "Finish with a compact encouragement."
        ],
    },
}
# Phrases a generated message must not contain (checked case-insensitively).
# Currently empty, so the ban validator always passes.
BAN_PHRASES = [
]
# Candidate opening lines; one is picked at random per prompt and the model is
# told to start with it.
OPENERS = [
    "Hereโs to a moment that fits you perfectly:",
    "A note made just for you:",
    "Because you make celebrations easy to love:",
    "For a day that sounds like you:",
]
# Candidate closing lines; one is picked at random per prompt as a style hint
# the model is asked to rephrase (not quote).
CLOSERS = [
    "Enjoy every bitโyouโve earned it.",
    "Keep doing the things that light you up.",
    "Hereโs to more of what makes you, you.",
    "Let this be a spark for the year ahead.",
]
def _msg_load():
    # Return the (tokenizer, model) pair for message generation, loading both on
    # first use and caching them in module globals; the model is put on
    # MSG_DEVICE in eval mode.
    global _msg_tok, _msg_mdl
    if _msg_tok is not None and _msg_mdl is not None:
        return _msg_tok, _msg_mdl
    _msg_tok = AutoTokenizer.from_pretrained(MSG_MODEL_ID)
    _msg_mdl = AutoModelForSeq2SeqLM.from_pretrained(MSG_MODEL_ID)
    _msg_mdl.to(MSG_DEVICE).eval()
    return _msg_tok, _msg_mdl
def _norm(s: str) -> str: | |
# Collapse whitespace for more reliable validators | |
return re.sub(r"\s+", " ", s or "").strip() | |
def _sentences_n(s: str) -> int: | |
# Count sentences via punctuation boundaries | |
return len([p for p in re.split(r"(?<=[.!?])\s+", s.strip()) if p]) | |
def _contains_any(text: str, terms: List[str]) -> bool: | |
# Case-insensitive containment check for any of the given terms | |
t = text.lower() | |
return any(term for term in terms if term) and any((term or "").lower() in t for term in terms) | |
def _too_similar(a: str, b: str, n=3, thr=0.85) -> bool: | |
# Approximate de-duplication via n-gram Jaccard similarity | |
def ngrams(txt): | |
toks = re.findall(r"[a-zA-Z']+", txt.lower()) | |
return set(tuple(toks[i:i+n]) for i in range(max(0, len(toks)-n+1))) | |
A, B = ngrams(a), ngrams(b) | |
if not A or not B: return False | |
j = len(A & B) / max(1, len(A | B)) | |
return j >= thr | |
def _clean_occasion(occ: str) -> str: | |
# Normalize typographic apostrophes to ASCII and trim | |
return (occ or "").replace("โ","'").strip() | |
def _build_prompt(profile: Dict[str, Any]) -> Tuple[str, Dict[str,str]]:
    # Compose a guided prompt (tone + micro-rules) for the message LLM.
    # Returns (prompt_text, need), where need = {"name", "occ"} holds the strings
    # the validators later look for in the generated message.
    # Draws three values from the global `random` module (opener, closer, spice),
    # in that order; callers that seed `random` get a reproducible prompt.
    # Blank/None fields fall back to defaults so need["name"]/need["occ"] are
    # never empty.
    name = (profile.get("recipient_name") or "").strip() or "Friend"
    rel = profile.get("relationship") or "Friend"
    occ = _clean_occasion(profile.get("occ_ui") or profile.get("occasion") or "Birthday") or "Birthday"
    tone = profile.get("tone") or "Heartfelt"
    ints = ", ".join(profile.get("interests") or []) or "general interests"
    style = TONE_STYLES.get(tone, TONE_STYLES["Heartfelt"])
    opener = random.choice(OPENERS)
    closer = random.choice(CLOSERS)
    spice = random.choice([
        "Use one concrete visual detail.",
        "Shift the rhythm slightly in the second sentence.",
        "Let one verb carry most of the energy; keep adjectives minimal.",
        "Add a gentle internal rhyme."
    ])
    # Rule texts may contain a literal "{name}" placeholder (e.g. the Formal tone);
    # substitute it so the model never sees the raw placeholder.
    rule_lines = ["- " + r.replace("{name}", name) for r in style["rules"]]
    lines = [
        "Generate a short gift-card message in English (2โ3 sentences).",
        f"Recipient: {name} ({rel}). Occasion: {occ}. Interests: {ints}. Tone: {tone}.",
        style["system"],
        "Rules:",
        *rule_lines,
        "- No emojis. No bullet points.",
        f"- Start with: \"{opener}\" (continue naturally, not as a header).",
        f"- End with a natural line similar to: \"{closer}\" (rephrase; do not quote).",
        f"- {spice}",
        "Output only the message; no extra commentary.",
    ]
    return "\n".join(lines), dict(name=name, occ=occ)
def _occasion_terms(occ: str) -> List[str]:
    # Lower-cased occasion strings a message may mention to pass validation: the
    # whole occasion plus its first word. Safe for a blank occasion, where the
    # old inline `occ.split()[0]` raised IndexError.
    occ_l = (occ or "").lower()
    words = occ_l.split()
    return [occ_l, words[0]] if words else [occ_l]
def _message_passes(text: str, need: Dict[str, str], prev: Optional[str]) -> bool:
    # Validators (mirrors the Colab logic): 1-3 sentences, mentions the recipient
    # name and the occasion, contains no BAN_PHRASES, and is not a near-duplicate
    # of prev (when prev is given).
    ok_len = 1 <= _sentences_n(text) <= 3
    name_ok = _contains_any(text, [need["name"].lower()])
    occ_ok = _contains_any(text, _occasion_terms(need["occ"]))
    ban_ok = not _contains_any(text, BAN_PHRASES)
    dup_ok = prev is None or not _too_similar(text, prev, n=3, thr=0.85)
    return ok_len and name_ok and occ_ok and ban_ok and dup_ok
def generate_personal_message(profile: Dict[str, Any], seed: Optional[int]=None, previous_message: Optional[str]=None) -> Dict[str, Any]:
    # Sample up to MSG_MAX_TRIES messages with slightly varied sampling settings and
    # return the first that passes _message_passes; otherwise return a fallback.
    # Returns {"message": str, "meta": {...}}; meta carries "failed": True on fallback.
    # Side effects: reseeds the global `random` and torch RNGs on every attempt and
    # stores the returned text in the module-level _last_msg.
    global _last_msg
    tok, mdl = _msg_load()
    if seed is None:
        seed = random.randint(1, 10_000_000)
    tried = []
    for attempt in range(1, MSG_MAX_TRIES+1):
        # Seeding before _build_prompt makes opener/closer/spice and the sampling
        # settings reproducible for a given seed.
        random.seed(seed); torch.manual_seed(seed)
        prompt, need = _build_prompt(profile)
        temp = random.uniform(*TEMP_RANGE)
        topp = random.uniform(*TOPP_RANGE)
        enc = tok(prompt, truncation=True, max_length=512, return_tensors="pt").to(MSG_DEVICE)
        out_ids = mdl.generate(
            **enc,
            do_sample=True,
            temperature=temp,
            top_p=topp,
            max_new_tokens=MSG_MAX_NEW_TOKENS,
            repetition_penalty=REP_PENALTY,
            pad_token_id=tok.eos_token_id,
            eos_token_id=tok.eos_token_id,
        )
        text = _norm(tok.decode(out_ids[0], skip_special_tokens=True))
        if _message_passes(text, need, previous_message or _last_msg):
            _last_msg = text
            return {"message": text, "meta": {"tone": profile.get("tone","Heartfelt"),
                    "temperature": round(temp,2), "top_p": round(topp,2),
                    "seed": seed, "attempt": attempt, "model": MSG_MODEL_ID}}
        tried.append({"text": text}); seed += 17
    # Every attempt failed validation: prefer the most recent non-empty attempt,
    # otherwise a minimal template (an empty decode would show a blank message).
    occ_label = _clean_occasion(profile.get("occ_ui") or "day") or "day"
    template = f"Happy {occ_label.lower()}, {profile.get('recipient_name') or 'Friend'}!"
    fallback = next((t["text"] for t in reversed(tried) if t["text"]), template)
    _last_msg = fallback
    return {"message": fallback, "meta": {"failed": True, "model": MSG_MODEL_ID, "tone": profile.get("tone","Heartfelt")}}
# --------------------- END Personalized Message --------------------- | |
# ===== Rendering & UI ===== | |
def first_sentence(s,max_chars=140):
    # Return the text before the first ". " separator (whole text if none), clipped
    # to max_chars with a trailing ellipsis when too long. Keeps HTML cards compact.
    text = (s or "").strip()
    if not text:
        return ""
    head, _sep, _rest = text.partition(". ")
    if len(head) > max_chars:
        return head[:max_chars-1] + "โฆ"
    return head
def render_top3_html(df, age_label):
    # Render result cards as HTML: rows at positions 0-2 are catalog picks tagged
    # "#1".."#3"; the row at position 3 (when present) is tagged "Generated".
    # Dataset/model text is HTML-escaped because gr.HTML renders raw markup; the
    # previous Markdown-style backslash escapes appeared literally on the page.
    import html
    if df is None or df.empty: return "<em>No results found within the current filters.</em>"
    def _cell_text(value) -> str:
        # Missing cells come back as None or float NaN; render those as "".
        if value is None or (isinstance(value, float) and value != value):
            return ""
        return str(value)
    age_html = html.escape(str(age_label))
    rows=[]
    # Tag by row position rather than index label: the frame's index labels are
    # not guaranteed to be 0..n-1.
    for pos, (_, r) in enumerate(df.iterrows()):
        name = html.escape(_cell_text(r.get("name", "")))
        desc = html.escape(first_sentence(_cell_text(r.get("short_desc", ""))))
        price = r.get("price_usd"); sim = r.get("similarity")
        img = _cell_text(r.get("image_url", ""))
        price_str = f"${price:.0f}" if pd.notna(price) else "N/A"
        sim_str = f"{sim:.3f}" if pd.notna(sim) else "โ"
        img_html = f'<img src="{html.escape(img, quote=True)}" alt="" style="width:84px;height:84px;object-fit:cover;border-radius:10px;margin-left:12px;" />' if img else ""
        tag = "Generated" if pos == 3 else f"#{pos + 1}"
        rows.append(f"""
<div style="display:flex;align-items:flex-start;justify-content:space-between;gap:10px;padding:10px;border:1px solid #eee;border-radius:12px;margin-bottom:8px;background:#fff;">
  <div style="flex:1;min-width:0;"><div style="font-weight:700;">{name} <span style="font-size:.8em;opacity:.7;">({tag})</span></div>
  <div style="font-size:0.95em;margin-top:4px;">{desc}</div>
  <div style="font-size:0.9em;margin-top:6px;opacity:0.8;">Price: <b>{price_str}</b> ยท Age: <code>{age_html}</code> ยท Score: <code>{sim_str}</code></div>
  </div>{img_html}
</div>""")
    return "\n".join(rows)
# Build the Gradio app. The CSS turns the examples table (.gr-dataframe) into
# card-like clickable rows, hides its header row, and suppresses cell focus borders.
with gr.Blocks(title="๐ GIfty โ Recommender + DIY", css="""
#explain{opacity:.85;font-size:.92em;margin-bottom:8px;}
.gr-dataframe thead{display:none;}
.gr-dataframe table{border-collapse:separate!important;border-spacing:0 10px!important;table-layout:fixed;width:100%;}
.gr-dataframe tbody tr{cursor:pointer;display:block;background:linear-gradient(180deg,#fff,#fafafa);border-radius:14px;border:1px solid #e9eef5;box-shadow:0 1px 1px rgba(16,24,40,.04),0 1px 2px rgba(16,24,40,.06);padding:10px 12px;transition:transform .06s ease, box-shadow .12s ease, background .12s ease;}
.gr-dataframe tbody tr:hover{transform:translateY(-1px);background:#f8fafc;box-shadow:0 3px 10px rgba(16,24,40,.08);}
.gr-dataframe tbody tr td{border:0!important;padding:4px 8px!important;vertical-align:middle;font-size:.92rem;line-height:1.3;}
.gr-dataframe tbody tr td:nth-child(1){font-weight:700;font-size:1rem;letter-spacing:.2px;}
.gr-dataframe tbody tr td:nth-child(2),.gr-dataframe tbody tr td:nth-child(4){opacity:.8;}
.gr-dataframe tbody tr td:nth-child(3),.gr-dataframe tbody tr td:nth-child(9),.gr-dataframe tbody tr td:nth-child(6),.gr-dataframe tbody tr td:nth-child(5){display:inline-block;background:#eff4ff;color:#243b6b;border:1px solid #dbe5ff;border-radius:999px;padding:2px 10px!important;font-size:.84rem;margin:2px 6px 2px 0;}
.gr-dataframe tbody tr td:nth-child(7),.gr-dataframe tbody tr td:nth-child(8){display:inline-block;background:#f1f5f9;border:1px solid #e2e8f0;color:#0f172a;border-radius:10px;padding:2px 8px!important;font-variant-numeric:tabular-nums;margin:2px 6px 2px 0;}
.handsontable .wtBorder,.handsontable .htBorders,.handsontable .wtBorder.current{display:none!important;}
.gr-dataframe table td:focus{outline:none!important;box-shadow:none!important;}
""") as demo:
    gr.Markdown(TITLE)
    gr.Markdown("### Quick examples (click a row to auto-fill)", elem_id="explain")
    # Example presets, one tuple per row, in the order:
    # (interests, occasion, budget_min, budget_max, name, relationship, age group, gender, tone).
    EXAMPLES=[(["Technology","Movies"],"Birthday",25,45,"Daniel","Friend","adult (18โ64)","male","Funny"),
             (["Art","Reading","Home decor"],"Anniversary",30,60,"Rotem","Romantic partner","adult (18โ64)","female","Romantic"),
             (["Gaming","Photography"],"Birthday",30,120,"Omer","Family - Sibling","teen (13โ17)","male","Playful"),
             (["Reading","Art"],"Graduation",15,35,"Maya","Friend","adult (18โ64)","female","Heartfelt"),
             (["Science","Crafts"],"Holidays",15,30,"Adam","Family - Child","kid (3โ12)","male","Encouraging")]
    # Display columns for the examples table; EX_DF reorders each tuple to this
    # order and joins interests with " + " (split back apart on row click).
    EX_COLS=["Recipient","Relationship","Interests","Occasion","Age group","Gender","Min $","Max $","Tone"]
    EX_DF=pd.DataFrame([[name,rel," + ".join(interests),occ,age,gender,bmin,bmax,tone] for (interests,occ,bmin,bmax,name,rel,age,gender,tone) in EXAMPLES], columns=EX_COLS)
    ex_df=gr.Dataframe(value=EX_DF, interactive=False, wrap=True); gr.Markdown("---")
    # ---- Input form ----
    with gr.Row():
        recipient_name=gr.Textbox(label="Recipient name", value="Daniel")
        relationship=gr.Dropdown(label="Relationship", choices=RECIPIENT_RELATIONSHIPS, value="Friend")
    with gr.Row():
        occasion=gr.Dropdown(label="Occasion", choices=OCCASION_UI, value="Birthday")
        age=gr.Dropdown(label="Age group", choices=list(AGE_OPTIONS.keys()), value="adult (18โ64)")
        gender=gr.Dropdown(label="Recipient gender", choices=GENDER_OPTIONS, value="male")
    interests=gr.CheckboxGroup(label="Interests (select a few)", choices=INTEREST_OPTIONS, value=["Technology","Movies"], interactive=True)
    with gr.Row():
        budget_min=gr.Slider(label="Min budget (USD)", minimum=5, maximum=500, step=1, value=25)
        budget_max=gr.Slider(label="Max budget (USD)", minimum=5, maximum=500, step=1, value=45)
        tone=gr.Dropdown(label="Message tone", choices=MESSAGE_TONES, value="Funny")
    go=gr.Button("Get GIfty!")
    # ---- Output sections (filled by the event chain wired further below) ----
    gr.Markdown("### ๐ Input summary"); out_summary = gr.HTML(visible=False)
    gr.Markdown("### ๐ฏ Recommendations"); out_top3=gr.HTML()
    gr.Markdown("### ๐ ๏ธ DIY Gift"); out_diy_md=gr.Markdown()
    gr.Markdown("### ๐ Personalized Message"); out_msg=gr.Markdown()
    # Per-click counter: start_run increments it and each handler passes it through.
    run_token=gr.State(0)
def _on_example_select(evt: gr.SelectData): | |
# Clicking a row fills the input widgets with that example | |
r=int(evt.index[0] if isinstance(evt.index,(list,tuple)) else evt.index); row=EX_DF.iloc[r]; ints=[s.strip() for s in str(row["Interests"]).split("+")] | |
return (ints,row["Occasion"],int(row["Min $"]),int(row["Max $"]),row["Recipient"],row["Relationship"],row["Age group"],row["Gender"],row["Tone"]) | |
ex_df.select(_on_example_select, outputs=[interests, occasion, budget_min, budget_max, recipient_name, relationship, age, gender, tone]) | |
def render_diy_md(j:dict)->str: | |
# Nicely format the DIY object as markdown | |
if not j: return "_DIY generation failed._" | |
steps=j.get('step_by_step_instructions', j.get('steps', [])) | |
parts = [ | |
f"**{j.get('gift_name','(no name)')}**","", | |
j.get('overview','').strip(),"", | |
"**Materials**","\n".join(f"- {m}" for m in j.get('materials_needed',[])),"", | |
"**Steps**","\n".join(f"{i+1}. {s}" for i,s in enumerate(steps)),"", | |
f"**Estimated cost:** ${j.get('estimated_cost_usd','?')} ยท **Time:** {j.get('estimated_time_minutes','?')} min" | |
] | |
return "\n".join(parts) | |
def input_summary_html(p, age_label): | |
# Render a compact summary of the current input above the results | |
ints = ", ".join(p.get("interests", [])) or "โ" | |
budget = f"${int(float(p.get('budget_min',0)))}โ${int(float(p.get('budget_max',0)))}" | |
name = p.get("recipient_name","Friend"); rel = p.get("relationship","Friend") | |
occ = p.get("occ_ui", "Birthday"); gender = (p.get("gender","any") or "any").capitalize() | |
return f""" | |
<div style="padding:10px 12px;border:1px solid #e2e8f0;border-radius:12px;background:#f8fafc;margin-bottom:8px;"> | |
<div style="display:flex;flex-wrap:wrap;gap:10px;align-items:center;"> | |
<div><b>Recipient:</b> {name} ({rel})</div> | |
<div><b>Occasion:</b> {occ}</div> | |
<div><b>Age:</b> {age_label}</div> | |
<div><b>Gender:</b> {gender}</div> | |
<div><b>Budget:</b> {budget}</div> | |
<div style="flex-basis:100%;height:0;"></div> | |
<div><b>Interests:</b> {ints}</div> | |
</div> | |
</div> | |
""" | |
def _build_profile(ints, occ, bmin, bmax, name, rel, age_label, gender_val, tone_val): | |
# Convert UI widget values into an internal profile dict | |
try: bmin=float(bmin); bmax=float(bmax) | |
except: bmin,bmax=5.0,500.0 | |
if bmin>bmax: bmin,bmax=bmax,bmin | |
return {"recipient_name":name or "Friend","relationship":rel or "Friend","interests":ints or [],"occ_ui":occ or "Birthday","budget_min":bmin,"budget_max":bmax,"age_range":AGE_OPTIONS.get(age_label,"any"),"gender":(gender_val or "any").lower(),"tone":tone_val or "Heartfelt"} | |
def start_run(curr): | |
# Simple monotonic counter to tie together chained events | |
return int(curr or 0) + 1 | |
def predict_summary_only(rt, *args): | |
# args mapping: | |
# 0: interests, 1: occasion, 2: budget_min, 3: budget_max, | |
# 4: recipient_name, 5: relationship, 6: age_label, 7: gender, 8: tone | |
p = _build_profile(*args) | |
return gr.update(value=input_summary_html(p, args[6]), visible=True), rt | |
def predict_recs_only(rt, *args): | |
p = _build_profile(*args) | |
top3 = recommend_top3_budget_first(p, include_synth=False) # ืืืืจ | |
return gr.update(value=render_top3_html(top3, args[6]), visible=True), rt | |
def predict_recs_with_synth(rt, *args): | |
p = _build_profile(*args) | |
synth_n = int(os.getenv("SYNTH_N", "2")) | |
df = recommend_top3_budget_first(p, include_synth=True, synth_n=synth_n) | |
return gr.update(value=render_top3_html(df, args[6]), visible=True), rt | |
def predict_diy_only(rt, *args): | |
p = _build_profile(*args) | |
diy_json, _ = diy_generate(p) | |
return gr.update(value=render_diy_md(diy_json), visible=True), rt | |
def predict_msg_only(rt, *args): | |
p = _build_profile(*args) | |
msg_obj = generate_personal_message(p) | |
return gr.update(value=msg_obj["message"], visible=True), rt | |
    # Event chain for the "Get GIfty!" button: start_run bumps the run token, and
    # every step below is chained off that first event with .then().
    # NOTE(review): each handler returns rt unchanged and nothing compares it with
    # the latest token, so on its own it does not discard results of older clicks.
    ev_start = go.click(start_run, inputs=[run_token], outputs=[run_token], queue=True)
    # 1) Input summary (immediate)
    ev_start.then(
        predict_summary_only,
        inputs=[run_token, interests, occasion, budget_min, budget_max, recipient_name, relationship, age, gender, tone],
        outputs=[out_summary, run_token],
        queue=True,
    )
    # 2) Fast recommendations (Top-3 without synthetic)
    recs_fast = ev_start.then(
        predict_recs_only,
        inputs=[run_token, interests, occasion, budget_min, budget_max, recipient_name, relationship, age, gender, tone],
        outputs=[out_top3, run_token],
        queue=True,
    )
    # 3) Compute the synthetic candidate in the background - refreshes the same out_top3 when it finishes
    recs_fast.then(
        predict_recs_with_synth,
        inputs=[run_token, interests, occasion, budget_min, budget_max, recipient_name, relationship, age, gender, tone],
        outputs=[out_top3, run_token],
        queue=True,
    )
    # 4) DIY and Message can run in parallel with (3)
    ev_start.then(
        predict_diy_only,
        inputs=[run_token, interests, occasion, budget_min, budget_max, recipient_name, relationship, age, gender, tone],
        outputs=[out_diy_md, run_token],
        queue=True,
    )
    ev_start.then(
        predict_msg_only,
        inputs=[run_token, interests, occasion, budget_min, budget_max, recipient_name, relationship, age, gender, tone],
        outputs=[out_msg, run_token],
        queue=True,
    )
if __name__=="__main__":
    # Start the Gradio server only when this file is executed as a script.
    demo.launch()