from __future__ import annotations # === Writable config dirs (must come right after future import) === import os os.environ.setdefault("APP_DATA_DIR", "/data/app_data" if os.access("/data", os.W_OK) else "/tmp/app_data") os.environ.setdefault("MPLCONFIGDIR", os.path.join(os.environ["APP_DATA_DIR"], "mplconfig")) os.makedirs(os.environ["MPLCONFIGDIR"], exist_ok=True) # ================================================================= import json import uuid from datetime import datetime from typing import Dict import gradio as gr import pandas as pd from app.storage import ( init_db, insert_variant, upsert_campaign, get_variant, get_metrics, get_campaign_value_per_conversion, log_event ) from app.bandit import ThompsonBandit from app.forecast import SeasonalityModel from app.compliance import rule_based_check, llm_check_and_fix from app.openai_client import openai_chat_json # 初期化 init_db() _seasonality_cache: Dict[str, SeasonalityModel] = {} # 固定カラム(常にこの形で返す) GENERATE_COLUMNS = ["variant_id", "status", "rejection_reason", "text"] REPORT_COLUMNS = ["variant_id","impressions","clicks","conversions","ctr","cvr","expected_value"] # JSONモード前提の厳格プロンプト GEN_SYSTEM = """ あなたは日本語広告コピーのプロフェッショナルコピーライターです。 出力は**次のJSONオブジェクトのみ**で厳密に返してください。余計な文章・説明・前置きは禁止です。 形式: { "variants": [ {"headline": "全角15-25字程度", "body": "全角40-90字程度"}, ... ] } ルール: - 医薬効能の断定、100%、永久、即効、根拠のない数値などの誇大表現は禁止 - CTAは自然に - 日本語で、句読点や記号は自然に """ GEN_USER_TEMPLATE = """ ブランド: {brand} 商品/サービス: {product} 想定ターゲット: {target} トーン: {tone} 制約: {constraints} 生成本数: {k} 要件: - "variants" 配列の要素数は **ちょうど {k}** 件にしてください - 各要素は {{"headline": "...", "body": "..."}} のみ """ def _seasonal(campaign_id: str) -> SeasonalityModel: if campaign_id not in _seasonality_cache: m = SeasonalityModel(campaign_id) try: m.fit() except Exception: # 失敗してもフォールバックあり pass _seasonality_cache[campaign_id] = m return _seasonality_cache[campaign_id] def _safe_get_variants(data, k: int): """LLM応答から variants 配列を安全に取り出して正規化。失敗時は None を返す。""" items = [] if isinstance(data, dict) and isinstance(data.get("variants"), list): items = data["variants"] elif isinstance(data, list): items = data if not items or not all(isinstance(x, dict) for x in items): return None out = [] for it in items[:k]: out.append({ "headline": str(it.get("headline", "")).strip(), "body": str(it.get("body", "")).strip(), }) return out def _local_variants(brand: str, product: str, k: int): """LLMが不調でも UI を止めないための最終フォールバック生成(簡易・無害表現)。""" base_head = [ "使いやすさで選ばれています", "日々の習慣をシンプルに", "はじめてでも安心", "続けやすいサポートを", "いま必要な機能だけを" ] base_body = [ "{brand}の「{product}」。生活になじむ設計で、今日からムリなく始められます。まずは詳細をご覧ください。", "毎日を少しラクに。{brand}の{product}が、あなたの習慣づくりを後押しします。今すぐチェック。", "難しい操作は不要。{brand}の{product}なら、使い始めから自然に続けられます。詳しくはサイトへ。", "必要な情報をひと目で。{brand}の{product}で、日々の管理をシンプルに。詳細を見る。", "続けやすさを重視。{brand}の{product}で、小さな一歩から。" ] out = [] for i in range(k): hi = base_head[i % len(base_head)] bo = base_body[i % len(base_body)].format(brand=brand, product=product) out.append({"headline": hi, "body": bo}) return out async def ui_generate(campaign_id: str, brand: str, product: str, target: str, tone: str, k_variants: int, ng_words: str, value_per_conversion: float): k_variants = int(k_variants) # Slider の float を明示的に int 化 constraints = {"ng_words": [w.strip() for w in ng_words.splitlines() if w.strip()]} if ng_words else {} upsert_campaign( campaign_id, brand, product, target, tone, "ja", constraints, value_per_conversion ) user = GEN_USER_TEMPLATE.format( brand=brand, product=product, target=target, tone=tone, constraints=json.dumps(constraints, ensure_ascii=False), k=k_variants, ) # まずは通常プロンプトで JSON モード呼び出し items = None try: data = await openai_chat_json( [ {"role": "system", "content": GEN_SYSTEM}, {"role": "user", "content": user}, ], temperature=0.2, max_tokens=1200, ) items = _safe_get_variants(data, k_variants) except Exception: items = None # 失敗/空のときは、温度をさらに下げて再試行(より厳格に) if not items: try: retry_user = user + "\n\n注意: 'variants' は必ず指定件数、各要素は {\"headline\":\"...\",\"body\":\"...\"} のみ。" data = await openai_chat_json( [ {"role": "system", "content": GEN_SYSTEM}, {"role": "user", "content": retry_user}, ], temperature=0.1, max_tokens=1000, ) items = _safe_get_variants(data, k_variants) except Exception: items = None # それでも無理ならローカル生成(UIを止めない) if not items: items = _local_variants(brand, product, k_variants) rows = [] for it in items[:k_variants]: headline = it["headline"] body = it["body"] text = f"{headline}\n{body}".strip() vid = str(uuid.uuid4())[:8] ok_rule, bads = rule_based_check(text, (constraints or {}).get("ng_words")) rejection_reason = None status = "approved" if not ok_rule: ok_llm, reasons, fixed = llm_check_and_fix(text) if ok_llm: text = fixed or text else: status = "rejected" rejection_reason = "; ".join(bads + reasons) else: ok_llm, reasons, fixed = llm_check_and_fix(text) if not ok_llm: text = fixed or text insert_variant(campaign_id, vid, text, status, rejection_reason) rows.append({ "variant_id": vid, "status": status, "rejection_reason": rejection_reason or "", "text": text, }) # 常に固定カラムで返す(空でもカラムを持つDataFrame) df = pd.DataFrame(rows, columns=GENERATE_COLUMNS) return df def ui_serve(campaign_id: str, hour: int, segment: str): ctx = {"hour": int(hour), "segment": (segment or "").strip() or None} m = _seasonal(campaign_id) bandit = ThompsonBandit(campaign_id) vid, _ = bandit.sample_arm(ctx, m.expected_ctr) if not vid: raise gr.Error("配信可能なバリアントがありません。まずは Generate してください。") row = get_variant(campaign_id, vid) if not row: raise gr.Error("バリアントが見つかりません。") # impression 記録 log_event(campaign_id, vid, "impression", datetime.utcnow().isoformat(), None) ThompsonBandit.update_with_event(campaign_id, vid, "impression") return vid, row["text"] def ui_feedback(campaign_id: str, variant_id: str, event_type: str): if not variant_id: raise gr.Error("先に Serve してください。") log_event(campaign_id, variant_id, event_type, datetime.utcnow().isoformat(), None) ThompsonBandit.update_with_event(campaign_id, variant_id, event_type) return f"{event_type} を記録しました。" def ui_report(campaign_id: str): mets = get_metrics(campaign_id) vpc = get_campaign_value_per_conversion(campaign_id) rows = [] for r in mets: imp = int(r["impressions"]); clk = int(r["clicks"]); conv = int(r["conversions"]) ctr = (clk / imp) if imp > 0 else 0.0 cvr = (conv / clk) if clk > 0 else 0.0 ev = ctr * cvr * vpc rows.append({ "variant_id": r["variant_id"], "impressions": imp, "clicks": clk, "conversions": conv, "ctr": round(ctr, 4), "cvr": round(cvr, 4), "expected_value": round(ev, 6), }) # 常に固定カラムで返す(空でもカラムを持つDataFrame) df = pd.DataFrame(rows, columns=REPORT_COLUMNS) return df def ui_check(text: str): ok_rule, bads = rule_based_check(text, []) ok_llm, reasons, fixed = llm_check_and_fix(text) status = "pass" if (ok_rule and ok_llm) else "needs_fix" fixed_text = fixed or (text if status == "pass" else "") reasons_joined = "; ".join(bads + reasons) return status, reasons_joined, fixed_text with gr.Blocks(title="AdCopy MAB Optimizer", fill_height=True) as demo: gr.Markdown(""" # AdCopy MAB Optimizer(HF UI) **広告コピー自動生成 → Thompson Sampling(CTR×CVR) → レポート** を、Hugging Face Spaces 上で完結。 - LLM: OpenAI (`OPENAI_API_KEY` を Space Secrets に設定) - DB: SQLite(`/data/app_data/data.db` など、書き込み可能ディレクトリ) - 季節性: Prophet/NeuralProphet(なければ簡易ヒューリスティック) """) with gr.Tab("1) Generate"): with gr.Row(): campaign_id = gr.Textbox(label="campaign_id", value="cmp-demo", scale=1) k_variants = gr.Slider(1, 10, value=5, step=1, label="生成本数") value_per_conv = gr.Number(value=5000, label="value_per_conversion") brand = gr.Textbox(label="ブランド", value="SFM") product = gr.Textbox(label="商品/サービス", value="HbA1c測定アプリ") target = gr.Textbox(label="ターゲット", value="30-50代の健康意識が高い層") tone = gr.Textbox(label="トーン", value="エビデンス重視で安心感") ng_words = gr.Textbox(label="NGワード(改行区切り)", value="治る\n奇跡") btn_gen = gr.Button("広告案を生成&審査&保存") table_gen = gr.Dataframe(headers=GENERATE_COLUMNS, interactive=False) btn_gen.click(ui_generate, [campaign_id, brand, product, target, tone, k_variants, ng_words, value_per_conv], [table_gen]) with gr.Tab("2) Serve & Feedback"): with gr.Row(): campaign_id2 = gr.Textbox(label="campaign_id", value="cmp-demo", scale=1) hour = gr.Slider(0, 23, value=20, step=1, label="hour") segment = gr.Textbox(label="segment (任意)") btn_serve = gr.Button("Serve Ad(impressionを記録)") served_vid = gr.Textbox(label="served variant_id", interactive=False) served_text = gr.Textbox(label="served text", lines=6, interactive=False) btn_serve.click(ui_serve, [campaign_id2, hour, segment], [served_vid, served_text]) with gr.Row(): btn_click = gr.Button("Clickを記録") btn_conv = gr.Button("Conversionを記録") msg = gr.Markdown() btn_click.click(lambda cid, vid: ui_feedback(cid, vid, "click"), [campaign_id2, served_vid], [msg]) btn_conv.click(lambda cid, vid: ui_feedback(cid, vid, "conversion"), [campaign_id2, served_vid], [msg]) with gr.Tab("3) Report"): campaign_id3 = gr.Textbox(label="campaign_id", value="cmp-demo") btn_rep = gr.Button("更新") table_rep = gr.Dataframe(headers=REPORT_COLUMNS, interactive=False) btn_rep.click(ui_report, [campaign_id3], [table_rep]) with gr.Tab("4) Compliance Check"): cand = gr.Textbox(label="チェックする文面", lines=5) btn_chk = gr.Button("判定") status = gr.Textbox(label="status") reasons = gr.Textbox(label="reasons") fixed = gr.Textbox(label="fixed (修正案)", lines=5) btn_chk.click(ui_check, [cand], [status, reasons, fixed]) if __name__ == "__main__": demo.queue().launch(server_name="0.0.0.0", server_port=7860)