OGOGOG commited on
Commit
95e169d
·
verified ·
1 Parent(s): e0b8a90

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +290 -0
app.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import numpy as np
4
+ import gradio as gr
5
+ from datasets import load_dataset
6
+ from sentence_transformers import SentenceTransformer
7
+
8
+ # ========================
9
+ # Config
10
+ # ========================
11
+ DATASET_ID = "motimmom/cocktails_clean_nobrand"
12
+ EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
13
+ FLAVOR_BOOST = 0.20
14
+
15
+ # Use the image you uploaded at the root of the Space repo:
16
+ BACKGROUND_IMAGE_URL = "file=bar.jpg" # <-- safest: served by Gradio from your Space files
17
+
18
+ # If you prefer the remote URL, make sure the space name uses the HY-PHEN:
19
+ # BACKGROUND_IMAGE_URL = "https://huggingface.co/spaces/OGOGOG/AI-Bartender/resolve/main/bar.jpg"
20
+
21
+ # If dataset is private, add Space secret HF_TOKEN (read scope)
22
+ HF_READ_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
23
+ load_kwargs = {}
24
+ if HF_READ_TOKEN:
25
+ load_kwargs["token"] = HF_READ_TOKEN
26
+ load_kwargs["use_auth_token"] = HF_READ_TOKEN
27
+
28
+ # ========================
29
+ # Base & Flavor tagging rules
30
+ # ========================
31
+ BASE_SPIRITS = {
32
+ "vodka": [r"\bvodka\b"],
33
+ "gin": [r"\bgin\b"],
34
+ "rum": [r"\brum\b", r"\bwhite rum\b", r"\bdark rum\b"],
35
+ "tequila": [r"\btequila\b"],
36
+ "whiskey": [r"\bwhisk(?:e|)y\b", r"\bbourbon\b", r"\bscotch\b", r"\brye\b"],
37
+ "mezcal": [r"\bmezcal\b"],
38
+ "brandy": [r"\bbrandy\b", r"\bcognac\b"],
39
+ "vermouth": [r"\bvermouth\b"],
40
+ "other": [r"\btriple sec\b", r"\bliqueur\b", r"\bcointreau\b", r"\baperol\b", r"\bcampari\b"],
41
+ }
42
+ FLAVORS = {
43
+ "citrus": [r"lime", r"lemon", r"grapefruit", r"orange", r"\bcitrus\b"],
44
+ "sweet": [r"simple syrup", r"\bsugar\b", r"\bhoney\b", r"\bagave\b", r"\bmaple\b", r"\bgrenadine\b", r"\bvanilla\b", r"\bsweet\b"],
45
+ "sour": [r"\bsour\b", r"lemon juice", r"lime juice", r"\bacid\b"],
46
+ "bitter": [r"\bbitter", r"\bamaro\b", r"\bcampari\b", r"\baperol\b"],
47
+ "smoky": [r"\bsmoky\b", r"\bsmoked\b", r"\bmezcal\b", r"\bpeated\b"],
48
+ "spicy": [r"\bspicy\b", r"\bchili\b", r"\bginger\b", r"\bjalapeño\b", r"\bcayenne\b"],
49
+ "herbal": [r"\bmint\b", r"\bbasil\b", r"\brosemary\b", r"\bthyme\b", r"\bherb", r"\bchartreuse\b"],
50
+ "fruity": [r"pineapple", r"cranberr", r"strawberr", r"mango", r"passion", r"peach", r"\bfruit"],
51
+ "creamy": [r"\bcream\b", r"coconut cream", r"\begg white\b", r"\bcreamy\b"],
52
+ "floral": [r"\brose\b", r"\bviolet\b", r"\belderflower\b", r"\blavender\b", r"\bfloral\b"],
53
+ "refreshing": [r"soda water", r"\btonic\b", r"\bhighball\b", r"\bcollins\b", r"\bfizz\b", r"\brefreshing\b"],
54
+ "boozy": [r"\bstirred\b", r"\bmartini\b", r"old fashioned", r"\bboozy\b", r"\bstrong\b"],
55
+ }
56
+ FLAVOR_OPTIONS = list(FLAVORS.keys())
57
+
58
+ # ========================
59
+ # Robust extraction helpers (with measures)
60
+ # ========================
61
+ def _clean(s): return s.strip() if isinstance(s, str) else ""
62
+
63
+ def _norm_measure(s: str) -> str:
64
+ if not isinstance(s, str): return ""
65
+ s = re.sub(r"\s+", " ", s.strip())
66
+ s = re.sub(r"\bml\b", "ml", s, flags=re.I)
67
+ s = re.sub(r"\boz\b", "oz", s, flags=re.I)
68
+ s = re.sub(r"\btsp\b", "tsp", s, flags=re.I)
69
+ s = re.sub(r"\btbsp\b", "tbsp", s, flags=re.I)
70
+ return s
71
+
72
+ def _join_measure_name(measure, name):
73
+ m = _norm_measure(measure)
74
+ n = name.strip() if isinstance(name, str) else ""
75
+ if m and n: return f"{m} {n}"
76
+ return n or m
77
+
78
+ def _split_ingredient_blob(s):
79
+ if not isinstance(s, str): return []
80
+ parts = re.split(r"[,\n;•\-–]+", s)
81
+ return [p.strip() for p in parts if p and p.strip()]
82
+
83
+ def _from_list_of_pairs(val):
84
+ out_disp, out_tokens = [], []
85
+ for x in val:
86
+ if not isinstance(x, (list, tuple)) or len(x) == 0: continue
87
+ if len(x) == 1:
88
+ name = str(x[0]).strip()
89
+ if name: out_disp.append(name); out_tokens.append(name.lower()); continue
90
+ a, b = str(x[0]).strip(), str(x[1]).strip()
91
+ if re.search(r"\d", a) and not re.search(r"\d", b):
92
+ disp = _join_measure_name(a, b); out_disp.append(disp); out_tokens.append(b.lower())
93
+ elif re.search(r"\d", b) and not re.search(r"\d", a):
94
+ disp = _join_measure_name(b, a); out_disp.append(disp); out_tokens.append(a.lower())
95
+ else:
96
+ disp = (a + " " + b).strip(); out_disp.append(disp); out_tokens.append((b if len(b) > len(a) else a).lower())
97
+ return out_disp, out_tokens
98
+
99
+ def _from_list_of_dicts(val):
100
+ out_disp, out_tokens = [], []
101
+ for x in val:
102
+ if not isinstance(x, dict): continue
103
+ name = next((x[k].strip() for k in ["name","ingredient","item","raw","text","strIngredient"] if isinstance(x.get(k), str) and x[k].strip()), None)
104
+ meas = next((x[k].strip() for k in ["measure","qty","quantity","amount","unit","Measure","strMeasure"] if isinstance(x.get(k), str) and x[k].strip()), None)
105
+ if name and meas:
106
+ out_disp.append(_join_measure_name(meas, name)); out_tokens.append(name.lower())
107
+ elif name:
108
+ out_disp.append(name); out_tokens.append(name.lower())
109
+ return out_disp, out_tokens
110
+
111
+ def _ingredients_from_any(val):
112
+ if isinstance(val, str):
113
+ lines = _split_ingredient_blob(val)
114
+ tokens = []
115
+ for line in lines:
116
+ parts = re.split(r"\s+", line); idx = 0
117
+ for i, p in enumerate(parts):
118
+ if re.search(r"[A-Za-z]", p): idx = i; break
119
+ tokens.append(" ".join(parts[idx:]).lower())
120
+ return lines, tokens
121
+ if isinstance(val, list) and all(isinstance(x, str) for x in val):
122
+ disp = [x.strip() for x in val if x and x.strip()]
123
+ return disp, [x.lower().strip() for x in disp]
124
+ if isinstance(val, list) and any(isinstance(x, (list, tuple)) for x in val):
125
+ return _from_list_of_pairs(val)
126
+ if isinstance(val, list) and any(isinstance(x, dict) for x in val):
127
+ return _from_list_of_dicts(val)
128
+ return [], []
129
+
130
+ def _get_title(row, cols):
131
+ for k in ["title","name","cocktail_name","drink","Drink","strDrink"]:
132
+ if k in cols and _clean(row.get(k)): return _clean(row[k])
133
+ return "Untitled"
134
+
135
+ def _get_ingredients_with_measures(row, cols):
136
+ if "ingredient_tokens" in cols and row.get("ingredient_tokens"):
137
+ toks = [str(x).strip().lower() for x in row["ingredient_tokens"] if str(x).strip()]
138
+ for mkey in ["measure_tokens","measures","measure_list"]:
139
+ if mkey in cols and row.get(mkey) and isinstance(row[mkey], list) and len(row[mkey]) == len(toks):
140
+ disp = []
141
+ for m, n in zip(row[mkey], row["ingredient_tokens"]):
142
+ m = _norm_measure(str(m)); n = str(n).strip()
143
+ disp.append(_join_measure_name(m, n) if m else n)
144
+ return disp, toks
145
+ return toks, toks
146
+ for key in ["ingredients","ingredients_raw","raw_ingredients","Raw_Ingredients","Raw Ingredients","ingredient_list","ingredients_list"]:
147
+ if key in cols and row.get(key) not in (None, "", [], {}): return _ingredients_from_any(row[key])
148
+ return [], []
149
+
150
+ def tag_base(text):
151
+ t = text.lower()
152
+ for base, pats in BASE_SPIRITS.items():
153
+ if any(re.search(p, t) for p in pats): return base
154
+ return "other"
155
+
156
+ def tag_flavors(text):
157
+ t = text.lower(); tags = []
158
+ for flv, pats in FLAVORS.items():
159
+ if any(re.search(p, t) for p in pats): tags.append(flv)
160
+ return tags
161
+
162
+ # ========================
163
+ # Load dataset & build docs
164
+ # ========================
165
+ ds = load_dataset(DATASET_ID, split="train", **load_kwargs)
166
+ cols = ds.column_names
167
+
168
+ DOCS = []
169
+ for r in ds:
170
+ title = _get_title(r, cols)
171
+ ing_disp, ing_tokens = _get_ingredients_with_measures(r, cols)
172
+ ing_disp = [x for x in ing_disp if x]; ing_tokens = [x for x in ing_tokens if x]
173
+ fused = f"{title}\nIngredients: {', '.join(ing_tokens)}"
174
+ DOCS.append({
175
+ "title": title,
176
+ "ingredients_display": ing_disp,
177
+ "ingredients_tokens": ing_tokens,
178
+ "text": fused,
179
+ "base": tag_base(fused),
180
+ "flavors": tag_flavors(fused),
181
+ })
182
+
183
+ # ========================
184
+ # Embeddings
185
+ # ========================
186
+ encoder = SentenceTransformer(EMBED_MODEL)
187
+ doc_embs = encoder.encode([d["text"] for d in DOCS], normalize_embeddings=True, convert_to_numpy=True).astype("float32")
188
+
189
+ # ========================
190
+ # Pretty ingredient formatting
191
+ # ========================
192
+ _MEASURE_RE = re.compile(r"^\s*(?P<meas>(?:\d+(\.\d+)?|\d+\s*/\s*\d+|\d+\s*\d*/\d+)\s*(?:ml|oz|tsp|tbsp)?|\d+\s*(?:ml|oz|tsp|tbsp)|(?:dash|dashes|drop|drops|barspoon)s?)\b[\s\-–:]*", flags=re.I)
193
+
194
+ def _split_measure_name_line(line: str):
195
+ if not isinstance(line, str): return None, line
196
+ m = _MEASURE_RE.match(line.strip())
197
+ if m:
198
+ meas = _norm_measure(m.group("meas")); name = line[m.end():].strip()
199
+ return meas, name or ""
200
+ return "", line.strip()
201
+
202
+ def _format_ingredients_markdown(lines):
203
+ """Bullet points as 'Ingredient (amount)'. Also removes [ and ]."""
204
+ if not lines: return "—"
205
+ formatted = []
206
+ for ln in lines:
207
+ ln = ln.replace("[","").replace("]","")
208
+ meas, name = _split_measure_name_line(ln)
209
+ if name and meas: formatted.append(f"- {name} ({meas})")
210
+ elif name: formatted.append(f"- {name}")
211
+ else: formatted.append(f"- {ln}")
212
+ return "\n".join(formatted)
213
+
214
+ # ========================
215
+ # Recommendation
216
+ # ========================
217
+ def recommend(base_alcohol_text, flavor, top_k=3):
218
+ inferred_base = tag_base(base_alcohol_text or "")
219
+ if flavor not in FLAVOR_OPTIONS: return "Please choose a flavor."
220
+ idxs = [i for i, d in enumerate(DOCS) if d["base"] == inferred_base] or list(range(len(DOCS)))
221
+ q_text = f"Base spirit: {base_alcohol_text}. Flavor: {flavor}. Cocktail recipe."
222
+ q_emb = encoder.encode([q_text], normalize_embeddings=True, convert_to_numpy=True).astype("float32")[0]
223
+ sims = doc_embs[idxs].dot(q_emb)
224
+ scored = []
225
+ for pos, i in enumerate(idxs):
226
+ score = float(sims[pos]) + (FLAVOR_BOOST if flavor in DOCS[i]['flavors'] else 0.0)
227
+ scored.append((score, i))
228
+ scored.sort(reverse=True)
229
+ picks = scored[:max(1,int(top_k))]
230
+ if not picks: return "No matches found."
231
+ blocks = []
232
+ for sc, i in picks:
233
+ d = DOCS[i]
234
+ ing_lines = d["ingredients_display"] or d["ingredients_tokens"]
235
+ ing_md = _format_ingredients_markdown(ing_lines)
236
+ meta = f"**Base:** {d['base']} | **Flavor tags:** {', '.join(d['flavors']) or '—'} | **Score:** {sc:.3f}"
237
+ blocks.append(f"### {d['title']}\n{meta}\n\n**Ingredients:**\n{ing_md}")
238
+ return "\n\n---\n\n".join(blocks)
239
+
240
+ # ========================
241
+ # Background + UI (robust)
242
+ # ========================
243
+ CUSTOM_CSS = f"""
244
+ html, body, #root {{ height: 100%; }}
245
+ /* Background on BODY to avoid component stacking issues */
246
+ body {{
247
+ background-image: url('{BACKGROUND_IMAGE_URL}');
248
+ background-size: cover;
249
+ background-position: center;
250
+ background-attachment: fixed;
251
+ }}
252
+ /* Dark overlay for text contrast */
253
+ body::before {{
254
+ content: "";
255
+ position: fixed;
256
+ inset: 0;
257
+ background: rgba(0,0,0,0.30); /* slightly lighter so image shows */
258
+ z-index: 0;
259
+ }}
260
+ /* Make the app transparent and float above overlay */
261
+ .gradio-container {{ background: transparent !important; position: relative; z-index: 1; }}
262
+ .glass-card {{
263
+ background: rgba(255, 255, 255, 0.08);
264
+ backdrop-filter: blur(6px);
265
+ -webkit-backdrop-filter: blur(6px);
266
+ border-radius: 14px;
267
+ padding: 18px;
268
+ border: 1px solid rgba(255, 255, 255, 0.12);
269
+ }}
270
+ """
271
+
272
+ with gr.Blocks(css=CUSTOM_CSS) as demo:
273
+ with gr.Column(elem_classes=["glass-card"]):
274
+ gr.Markdown("# 🍹 AI Bartender — Type a Base + Flavor")
275
+ with gr.Row():
276
+ base_text = gr.Textbox(value="gin", label="Base alcohol (type any spirit, e.g., 'gin', 'white rum', 'bourbon')")
277
+ flavor = gr.Dropdown(choices=FLAVOR_OPTIONS, value="citrus", label="Flavor")
278
+ topk = gr.Slider(1, 10, value=3, step=1, label="Number of recommendations")
279
+ with gr.Row():
280
+ ex1 = gr.Button("Example: Gin + Citrus")
281
+ ex2 = gr.Button("Example: Rum + Fruity")
282
+ ex3 = gr.Button("Example: Mezcal + Smoky")
283
+ out = gr.Markdown()
284
+ gr.Button("Recommend").click(recommend, [base_text, flavor, topk], out)
285
+ ex1.click(lambda: ("gin", "citrus", 3), outputs=[base_text, flavor, topk])
286
+ ex2.click(lambda: ("white rum", "fruity", 3), outputs=[base_text, flavor, topk])
287
+ ex3.click(lambda: ("mezcal", "smoky", 3), outputs=[base_text, flavor, topk])
288
+
289
+ if __name__ == "__main__":
290
+ demo.launch()