Spaces:
Sleeping
Sleeping
| import os, re | |
| import numpy as np | |
| import pandas as pd | |
| import gradio as gr | |
| import faiss | |
| import torch | |
| from typing import List | |
| from sentence_transformers import SentenceTransformer, CrossEncoder | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
# ---- Config ----
# Seq2seq checkpoint used to sample slogans; override with the FLAN_PRIMARY env var.
FLAN_PRIMARY = os.environ.get("FLAN_PRIMARY", "google/flan-t5-large")
# Bi-encoder used for retrieval and for similarity-based penalties.
EMBED_NAME = "sentence-transformers/all-mpnet-base-v2"
# Cross-encoder used to score (query, slogan) relevance.
RERANK_NAME = "cross-encoder/stsb-roberta-base"
# Number of sampled sequences per generation call; override with NUM_SLOGAN_SAMPLES.
NUM_SLOGAN_SAMPLES = int(os.environ.get("NUM_SLOGAN_SAMPLES", "16"))
# Run the generator on the GPU whenever torch can see one.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Folder holding data.parquet, faiss.index and embeddings.npy.
ASSETS_DIR = "assets"
| # ---- Lazy models ---- | |
| _GEN_TOK = None | |
| _GEN_MODEL = None | |
| _EMBED_MODEL = None | |
| _RERANKER = None | |
def _ensure_models():
    """Load the embedder, the reranker and the slogan generator on first use.

    Each model is created at most once and cached in a module-level global.
    Once everything is populated, later calls return without doing any work.
    """
    global _GEN_TOK, _GEN_MODEL, _EMBED_MODEL, _RERANKER
    if _EMBED_MODEL is None:
        _EMBED_MODEL = SentenceTransformer(EMBED_NAME)
    if _RERANKER is None:
        _RERANKER = CrossEncoder(RERANK_NAME)
    if _GEN_MODEL is not None:
        return
    # The tokenizer and the model are published together, so no caller sees
    # one of them without the other.
    tokenizer = AutoTokenizer.from_pretrained(FLAN_PRIMARY)
    generator = AutoModelForSeq2SeqLM.from_pretrained(FLAN_PRIMARY).to(DEVICE)
    _GEN_TOK, _GEN_MODEL = tokenizer, generator
    print(f"[INFO] Loaded generator: {FLAN_PRIMARY}")
| # ---- Data & PRE-BUILT FAISS from assets folder ---- | |
| _DATA_DF = None | |
| _INDEX = None | |
| _EMBEDDINGS = None | |
def _ensure_index():
    """Load the startup catalogue and its FAISS index, once.

    Reads ``data.parquet``, ``faiss.index`` and ``embeddings.npy`` from
    ASSETS_DIR. It falls back to a three-row demo catalogue, embedded on the
    fly with the sentence-transformer, in two cases:

    * loading raises OSError, RuntimeError or ValueError (faiss reports a
      missing or unreadable index file as RuntimeError, not
      FileNotFoundError);
    * the index vector count does not match the table's row count. A mismatch
      would make ``recommend`` index rows that do not exist.

    Globals are assigned only after a complete, consistent set has been
    built. A failure part-way through loading therefore never leaves a
    half-initialised state behind. ``_INDEX`` is assigned last because it is
    the "already loaded" flag.
    """
    global _DATA_DF, _INDEX, _EMBEDDINGS
    if _INDEX is not None:
        return
    data_path = os.path.join(ASSETS_DIR, "data.parquet")
    index_path = os.path.join(ASSETS_DIR, "faiss.index")
    emb_path = os.path.join(ASSETS_DIR, "embeddings.npy")
    try:
        df = pd.read_parquet(data_path)
        index = faiss.read_index(index_path)
        embeddings = np.load(emb_path)
        if index.ntotal != len(df):
            raise ValueError(
                f"index holds {index.ntotal} vectors but data has {len(df)} rows"
            )
    except (OSError, RuntimeError, ValueError) as exc:
        print(f"[ERROR] Could not load pre-built assets: {exc}")
        print("[INFO] Falling back to building a tiny demo index.")
        df = pd.DataFrame({
            "name": ["HowDidIDo", "Museotainment", "Movitr"],
            "tagline": ["Online evaluation platform", "PacMan & Louvre meet", "Crowdsourced video translation"],
            "description": [
                "Public speaking, Presentation skills and interview practice",
                "Interactive AR museum tours",
                "Video translation with voice and subtitles"
            ]
        })
        _ensure_models()
        # Unit-normalised vectors + inner-product index == cosine similarity.
        embeddings = _EMBED_MODEL.encode(
            df["description"].astype(str).tolist(), normalize_embeddings=True
        ).astype(np.float32)
        index = faiss.IndexFlatIP(embeddings.shape[1])
        index.add(embeddings)
    else:
        print(f"[INFO] Loaded pre-built FAISS index. rows={len(df)}, dim={index.d}")
    _DATA_DF, _EMBEDDINGS, _INDEX = df, embeddings, index
def recommend(query_text: str, top_k: int = 3) -> pd.DataFrame:
    """Return up to ``top_k`` catalogue rows most similar to ``query_text``.

    The query is embedded with unit normalisation and searched in the FAISS
    index. ``score`` is the value the index returns (inner product for the
    fallback ``IndexFlatIP``), best first.

    Returns a DataFrame with columns name, tagline, description and score.
    It has fewer than ``top_k`` rows when the index is smaller, and is empty
    when ``top_k <= 0``.
    """
    _ensure_index()
    _ensure_models()
    columns = ["name", "tagline", "description", "score"]
    # FAISS pads missing results with id -1, and iloc[-1] would silently
    # return the last row. Clamp k to the index size and drop any padding.
    k = min(int(top_k), _INDEX.ntotal)
    if k <= 0:
        return pd.DataFrame(columns=columns)
    q_vec = _EMBED_MODEL.encode([query_text], normalize_embeddings=True).astype("float32")
    scores, idxs = _INDEX.search(q_vec, k)
    keep = idxs[0] >= 0
    out = _DATA_DF.iloc[idxs[0][keep]].copy()
    out["score"] = scores[0][keep]
    return out[columns]
# ---- Refined v2 slogan generator (unchanged logic) ----
# Case-sensitive regexes for bland, name-like candidates: "Xxx Yyy Platform"-style
# product names, and bare two-word or one-word Title Case phrases.
BLOCK_PATTERNS = [
    r"^[A-Z][a-z]+ [A-Z][a-z]+ (Platform|Solution|System|Application|Marketplace)$",
    r"^[A-Z][a-z]+ [A-Z][a-z]+$",
    r"^[A-Z][a-z]+$",
]
# Lower-case phrases; a candidate containing any of them is rejected outright.
HARD_BLOCK_WORDS = {
    "platform",
    "solution",
    "system",
    "application",
    "marketplace",
    "ai-powered",
    "ai powered",
    "empower",
    "empowering",
    "artificial intelligence",
    "machine learning",
    "augmented reality",
    "virtual reality",
}
# Buzzwords that cost a small score penalty per hit.
GENERIC_WORDS = set("app assistant smart ai ml ar vr decentralized blockchain".split())
# Action verbs and benefit words that earn a small score bonus.
MARKETING_VERBS = set(
    "build grow simplify discover create connect transform unlock boost learn move clarify".split()
)
BENEFIT_WORDS = set(
    "faster smarter easier better safer clearer stronger together confidently simply instantly".split()
)
# Known-good slogans (lower-cased), including the prompt's few-shot examples;
# a candidate equal to one of them is rejected so the model cannot parrot them.
GOOD_SLOGANS_TO_AVOID_DUP = {
    "smarter care, faster decisions",
    "checkout built for small brands",
    "less guessing. more healing.",
    "built to grow with your cart.",
    "stand tall. feel better.",
    "train your brain to win.",
    "your body. your algorithm.",
    "play smarter. grow brighter.",
    "style that thinks with you.",
}
| def _tokens(s: str) -> List[str]: | |
| return re.findall(r"[a-z0-9]{3,}", s.lower()) | |
| def _jaccard(a: List[str], b: List[str]) -> float: | |
| A, B = set(a), set(b) | |
| return 0.0 if not A or not B else len(A & B) / len(A | B) | |
| def _titlecase_soft(s: str) -> str: | |
| out = [] | |
| for w in s.split(): | |
| out.append(w if w.isupper() else w.capitalize()) | |
| return " ".join(out) | |
def _is_blocked_slogan(s: str) -> bool:
    """Return True if candidate ``s`` must be discarded.

    A candidate is discarded when it:
    * is empty;
    * matches one of BLOCK_PATTERNS (checked case-sensitively on the
      stripped text);
    * contains any HARD_BLOCK_WORDS phrase as a case-insensitive substring;
    * equals, after lower-casing, one of GOOD_SLOGANS_TO_AVOID_DUP.
    """
    if not s:
        return True
    stripped = s.strip()
    if any(re.match(pattern, stripped) for pattern in BLOCK_PATTERNS):
        return True
    lowered = stripped.lower()
    return any(phrase in lowered for phrase in HARD_BLOCK_WORDS) or lowered in GOOD_SLOGANS_TO_AVOID_DUP
| def _generic_penalty(s: str) -> float: | |
| hits = sum(1 for w in GENERIC_WORDS if w in s.lower()) | |
| return min(1.0, 0.25 * hits) | |
| def _for_penalty(s: str) -> float: | |
| return 0.3 if re.search(r"\bfor\b", s.lower()) else 0.0 | |
| def _neighbor_context(neighbors_df: pd.DataFrame) -> str: | |
| if neighbors_df is None or neighbors_df.empty: | |
| return "" | |
| examples = [] | |
| for _, row in neighbors_df.head(3).iterrows(): | |
| tg = str(row.get("tagline", "")).strip() | |
| if 5 <= len(tg) <= 70: | |
| examples.append(f"- {tg}") | |
| return "\n".join(examples) | |
| def _copies_neighbor(s: str, neighbors_df: pd.DataFrame) -> bool: | |
| if neighbors_df is None or neighbors_df.empty: | |
| return False | |
| s_low = s.lower() | |
| s_toks = _tokens(s_low) | |
| for _, row in neighbors_df.iterrows(): | |
| t = str(row.get("tagline", "")).strip() | |
| if not t: | |
| continue | |
| t_low = t.lower() | |
| if s_low == t_low: | |
| return True | |
| if _jaccard(s_toks, _tokens(t_low)) >= 0.7: | |
| return True | |
| try: | |
| _ensure_models() | |
| s_vec = _EMBED_MODEL.encode([s])[0]; s_vec = s_vec / np.linalg.norm(s_vec) | |
| for _, row in neighbors_df.head(3).iterrows(): | |
| t = str(row.get("tagline", "")).strip() | |
| if not t: continue | |
| t_vec = _EMBED_MODEL.encode([t])[0]; t_vec = t_vec / np.linalg.norm(t_vec) | |
| if float(np.dot(s_vec, t_vec)) >= 0.85: | |
| return True | |
| except Exception: | |
| pass | |
| return False | |
| def _clean_slogan(text: str, max_words: int = 8) -> str: | |
| text = text.strip().split("\n")[0] | |
| text = re.sub(r"[\"“”‘’]", "", text) | |
| text = re.sub(r"\s+", " ", text).strip() | |
| text = re.sub(r"^\W+|\W+$", "", text) | |
| words = text.split() | |
| if len(words) > max_words: | |
| text = " ".join(words[:max_words]) | |
| return text | |
| def _score_candidates(query: str, cands: List[str], neighbors_df: pd.DataFrame) -> List[tuple]: | |
| if not cands: | |
| return [] | |
| _ensure_models() | |
| ce_scores = np.asarray(_RERANKER.predict([(query, s) for s in cands]), dtype=np.float32) / 5.0 | |
| q_toks = _tokens(query) | |
| results = [] | |
| neighbor_vecs = [] | |
| if neighbors_df is not None and not neighbors_df.empty: | |
| _ensure_models() | |
| for _, row in neighbors_df.head(3).iterrows(): | |
| t = str(row.get("tagline","")).strip() | |
| if t: | |
| v = _EMBED_MODEL.encode([t])[0] | |
| neighbor_vecs.append(v / np.linalg.norm(v)) | |
| for i, s in enumerate(cands): | |
| words = s.split() | |
| brevity = 1.0 - min(1.0, abs(len(words) - 5) / 5.0) | |
| wl = set(w.lower() for w in words) | |
| m_hits = len(wl & MARKETING_VERBS) | |
| b_hits = len(wl & BENEFIT_WORDS) | |
| marketing = min(1.0, 0.2*m_hits + 0.2*b_hits) | |
| g_pen = _generic_penalty(s) | |
| f_pen = _for_penalty(s) | |
| n_pen = 0.0 | |
| if neighbor_vecs: | |
| try: | |
| _ensure_models() | |
| s_vec = _EMBED_MODEL.encode([s])[0]; s_vec = s_vec / np.linalg.norm(s_vec) | |
| sim_max = max(float(np.dot(s_vec, nv)) for nv in neighbor_vecs) if neighbor_vecs else 0.0 | |
| n_pen = sim_max | |
| except Exception: | |
| n_pen = 0.0 | |
| overlap = _jaccard(q_toks, _tokens(s)) | |
| anti_copy = 1.0 - overlap | |
| score = ( | |
| 0.55*float(ce_scores[i]) + | |
| 0.20*brevity + | |
| 0.15*marketing + | |
| 0.03*anti_copy - | |
| 0.07*g_pen - | |
| 0.03*f_pen - | |
| 0.10*n_pen | |
| ) | |
| results.append((s, float(score))) | |
| return results | |
def _build_slogan_prompt(query_text: str, neighbors_df: pd.DataFrame) -> str:
    """Assemble the few-shot copywriting prompt, optionally with neighbour taglines as style hints."""
    prompt = (
        "You are a creative brand copywriter. Write short, original, memorable startup slogans (max 8 words).\n"
        "Forbidden words: app, assistant, platform, solution, system, marketplace, AI, machine learning, augmented reality, virtual reality, decentralized, empower.\n"
        "Focus on clear benefits and vivid verbs. Do not copy the description. Return ONLY a list, one slogan per line.\n\n"
        "Good Examples:\n"
        "Description: AI assistant for doctors to prioritize patient cases\n"
        "Slogan: Less Guessing. More Healing.\n\n"
        "Description: Payments for small online stores\n"
        "Slogan: Built to Grow with Your Cart.\n\n"
        "Description: Neurotech headset to boost focus\n"
        "Slogan: Train Your Brain to Win.\n\n"
        "Description: Interior design suggestions with AI\n"
        "Slogan: Style That Thinks With You.\n\n"
        "Bad Examples (avoid these): Innovative AI Platform / Smart App for Everyone / Empowering Small Businesses\n\n"
    )
    ctx = _neighbor_context(neighbors_df)
    if ctx:
        prompt += f"Similar taglines (style only):\n{ctx}\n\n"
    return prompt + f"Description: {query_text}\nSlogans:"


def _filter_candidates(raw_texts: List[str], neighbors_df: pd.DataFrame) -> List[str]:
    """Clean, filter and de-duplicate decoded samples into sorted, title-cased candidates."""
    cands = set()
    for txt in raw_texts:
        for line in txt.split("\n"):
            s = _clean_slogan(line)
            if not 2 <= len(s.split()) <= 8:
                continue
            # The blocklist check is cheap; _copies_neighbor may run the embedder.
            if _is_blocked_slogan(s) or _copies_neighbor(s, neighbors_df):
                continue
            cands.add(_titlecase_soft(s))
    return sorted(cands)


def generate_slogan(query_text: str, neighbors_df: pd.DataFrame = None, n_samples: int = NUM_SLOGAN_SAMPLES) -> str:
    """Sample slogans for ``query_text`` with the FLAN generator and return the best one.

    Flow:
    * sample ``n_samples`` sequences (at least 1) from a few-shot prompt;
    * clean and filter the samples (2-8 words, not blocklisted, not copying
      a neighbour tagline);
    * return the top candidate by ``_score_candidates`` (ties go to the
      first in sorted order).

    If nothing survives filtering, the first non-empty cleaned raw sample is
    returned unfiltered, or "" if every sample is empty.
    """
    _ensure_models()
    prompt = _build_slogan_prompt(query_text, neighbors_df)
    # Pass the attention mask along with input_ids instead of input_ids alone.
    inputs = _GEN_TOK(prompt, return_tensors="pt").to(DEVICE)
    outputs = _GEN_MODEL.generate(
        **inputs,
        max_new_tokens=24,
        do_sample=True,
        top_k=60,
        top_p=0.92,
        temperature=1.2,
        num_return_sequences=max(1, int(n_samples)),
        repetition_penalty=1.08,
    )
    raw_cands = _GEN_TOK.batch_decode(outputs, skip_special_tokens=True)
    scored = _score_candidates(query_text, _filter_candidates(raw_cands, neighbors_df), neighbors_df)
    if scored:
        # max() keeps the first of equal scores, like a stable descending sort.
        return max(scored, key=lambda item: item[1])[0]
    for txt in raw_cands:
        fallback = _clean_slogan(txt)
        if fallback:
            return fallback
    return ""
# ---- Gradio UI ----
# Sample startup descriptions offered as one-click inputs in the UI.
EXAMPLES = list((
    "AI coach for improving public speaking skills",
    "Augmented reality app for interactive museum tours",
    "Voice-controlled task manager for remote teams",
    "Machine learning system for predicting crop yields",
    "Platform for AI-assisted interior design suggestions",
))
def pipeline(user_input: str):
    """Gradio handler: return the top-3 similar startups plus a generated slogan.

    Returns ``(table, slogan)``. ``table`` has columns name, tagline,
    description and score: the retrieved rows followed by one
    "Synthetic Example" row carrying the generated slogan (score NaN).
    Blank or missing input returns an empty table and an empty slogan
    without running any model.
    """
    columns = ["name", "tagline", "description", "score"]
    if user_input is None or not str(user_input).strip():
        return pd.DataFrame(columns=columns), ""
    recs = recommend(user_input, top_k=3)
    slogan = generate_slogan(user_input, neighbors_df=recs, n_samples=NUM_SLOGAN_SAMPLES)
    recs = recs.reset_index(drop=True)
    # A fresh RangeIndex makes len(recs) the next unused row label.
    recs.loc[len(recs)] = {"name":"Synthetic Example","tagline":slogan,"description":user_input,"score":np.nan}
    return recs[columns], slogan
# Build the web UI. Left column: description input, example picker and submit
# button. Right column: results table and generated slogan.
with gr.Blocks(title="SloganAI — Recommendations + Slogan Generator") as demo:
    gr.Markdown("## SloganAI — Top-3 Recommendations + A High-Quality Generated Slogan")
    with gr.Row():
        with gr.Column(scale=1):
            inp = gr.Textbox(label="Enter a startup description", lines=3, placeholder="e.g., AI coach for improving public speaking skills")
            # Example descriptions (EXAMPLES) bound to the `inp` textbox.
            gr.Examples(EXAMPLES, inputs=inp, label="One-click examples")
            btn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=2):
            # Shows pipeline()'s table: retrieved rows plus the "Synthetic Example" row.
            out_df = gr.Dataframe(headers=["Name","Tagline","Description","Score"], label="Top 3 + Generated")
            out_sg = gr.Textbox(label="Generated Slogan", interactive=False)
    # pipeline() returns (table, slogan), mapped in order onto out_df and out_sg.
    btn.click(fn=pipeline, inputs=inp, outputs=[out_df, out_sg])
if __name__ == "__main__":
    # Load the models and the retrieval index up front. Otherwise the lazy
    # loaders would run on the first request that calls recommend().
    _ensure_models()
    _ensure_index()
    demo.queue().launch()