import sys
import os
import types
import logging
import re


# The stdlib ``audioop`` module was removed in Python 3.13.  Register an
# empty stand-in before importing gradio so transitive imports that still
# do ``import audioop`` (presumably pydub via gradio — TODO confirm) don't
# crash at startup.
if 'audioop' not in sys.modules:
    sys.modules['audioop'] = types.ModuleType('audioop')


import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
# Headless rendering backend: figures are only saved to files, never shown.
# NOTE(review): use() is conventionally called before importing pyplot;
# modern matplotlib switches backends fine here, but confirm the version.
matplotlib.use('Agg')


# Module-level logger; basicConfig enables INFO-level output.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| def _patched_get_transcript_from_audio(wav_filename, language="english"): |
| """CPU-safe replacement that creates word events from audio duration. |
| |
| When the audio was generated from known text (gTTS), the global |
| CURRENT_SCRIPT_TEXT will contain that text. Otherwise we create |
| a minimal placeholder so TRIBE's pipeline doesn't crash. |
| """ |
| import pandas as pd |
| import soundfile as sf |
| from pathlib import Path |
|
|
| wav_filename = Path(wav_filename) |
|
|
| |
| try: |
| info = sf.info(str(wav_filename)) |
| duration = info.duration |
| except Exception: |
| duration = 30.0 |
|
|
| |
| text = _CURRENT_SCRIPT_TEXT or "audio content placeholder" |
|
|
| |
| raw_words = text.split() |
| if not raw_words: |
| return pd.DataFrame(columns=["text", "start", "duration", "sequence_id", "sentence"]) |
|
|
| |
| sentences = re.split(r'(?<=[.!?])\s+', text) |
| sentences = [s.strip() for s in sentences if s.strip()] |
| if not sentences: |
| sentences = [text] |
|
|
| |
| word_duration = duration / len(raw_words) |
| words = [] |
| word_idx = 0 |
| for sent_idx, sentence in enumerate(sentences): |
| sent_words = sentence.split() |
| for w in sent_words: |
| if word_idx >= len(raw_words): |
| break |
| words.append({ |
| "text": w.replace('"', ''), |
| "start": word_idx * word_duration, |
| "duration": word_duration * 0.9, |
| "sequence_id": sent_idx, |
| "sentence": sentence.replace('"', ''), |
| }) |
| word_idx += 1 |
|
|
| return pd.DataFrame(words) |
|
|
|
|
| |
| _CURRENT_SCRIPT_TEXT = None |
|
|
|
|
def apply_patches():
    """Monkey-patch TRIBE so transcription never touches whisperx/CUDA.

    Swaps ExtractWordsFromAudio._get_transcript_from_audio for the
    CPU-safe _patched_get_transcript_from_audio.  Any failure (e.g.
    tribev2 not installed) is logged as a warning rather than raised,
    so importing this module never crashes.
    """
    try:
        from tribev2.eventstransforms import ExtractWordsFromAudio

        replacement = staticmethod(_patched_get_transcript_from_audio)
        ExtractWordsFromAudio._get_transcript_from_audio = replacement
        logger.info("Patched ExtractWordsFromAudio (CPU-safe, no whisperx)")
    except Exception as exc:
        logger.warning(f"Could not patch ExtractWordsFromAudio: {exc}")
|
|
| |
# Install the CPU-safe patch at import time, before any model is built.
apply_patches()
|
|
| |
| |
| |
# Lazily-loaded TribeModel singleton; populated by load_model().
model = None
|
|
def load_model():
    """Load the TRIBE v2 model into the module-level ``model`` singleton.

    Idempotent: returns immediately when the model is already loaded.
    Re-applies the whisperx patch first so a freshly imported tribev2
    class is patched even if the import-time attempt failed.

    Returns:
        A human-readable status string.  Failure messages contain the
        word "Error" — the marker analyze() keys off.
    """
    global model
    if model is not None:
        # Fix: this status string (and the success one below) was broken
        # across physical source lines, which is a SyntaxError.
        return "β Already loaded!"
    try:
        apply_patches()
        from tribev2 import TribeModel
        model = TribeModel.from_pretrained("facebook/tribev2", cache_folder="/tmp/tribe_cache")
        return "β Model loaded!"
    except Exception as e:
        import traceback
        traceback.print_exc()
        # Keep "Error" in the message — analyze() checks for it.
        return f"β Error loading model: {str(e)}"
|
|
| |
| |
| |
# (name, start fraction, end fraction, bar color) — fractions index into
# one half of the averaged prediction vector (see score_predictions).
REGIONS = [
    ("Visual cortex", 0.00, 0.15, "#378ADD"),
    ("Auditory cortex", 0.15, 0.30, "#D85A30"),
    ("Language (Broca's area)", 0.30, 0.45, "#7F77DD"),
    ("Prefrontal (attention)", 0.45, 0.62, "#1D9E75"),
    ("Temporal (memory)", 0.62, 0.78, "#BA7517"),
    ("Emotion (limbic)", 0.78, 1.00, "#D4537E"),
]


def score_predictions(preds):
    """Score each REGIONS entry from raw model predictions.

    Averages |preds| over the time axis, then rates each region's slice
    of the first half of that vector against the strongest feature.

    Args:
        preds: 2-D array-like, time steps x features.

    Returns:
        (scores, overall): dict of region name -> 0-100 score (rounded
        to one decimal) and the mean of those scores, also rounded.
    """
    magnitude = np.abs(preds).mean(axis=0)
    peak = magnitude.max() + 1e-8  # epsilon guards against divide-by-zero
    half = len(magnitude) // 2  # regions span one hemisphere's worth of features
    scores = {
        name: round(float(magnitude[int(half * lo):int(half * hi)].mean() / peak * 100), 1)
        for name, lo, hi, _color in REGIONS
    }
    overall = round(sum(scores.values()) / len(scores), 1)
    return scores, overall
|
|
def make_brain_plot(preds):
    """Render left/right fsaverage surface maps of |preds| to a PNG.

    Normalizes the time-averaged absolute predictions to [0, 1] and
    paints each half on one hemisphere of the inflated fsaverage5 mesh.

    Returns:
        Path of the saved PNG, or None when nilearn (or its fsaverage
        download) is unavailable — the app degrades gracefully.
    """
    try:
        from nilearn import datasets, plotting

        activation = np.abs(preds).mean(axis=0)
        span = activation.max() - activation.min() + 1e-8
        normed = (activation - activation.min()) / span
        mid = len(normed) // 2  # first half -> left hemisphere, rest -> right
        surf = datasets.fetch_surf_fsaverage("fsaverage5")

        fig, (ax_left, ax_right) = plt.subplots(
            1, 2, figsize=(14, 5), subplot_kw={"projection": "3d"}
        )
        fig.patch.set_facecolor("#111111")
        plotting.plot_surf_stat_map(
            surf.infl_left, normed[:mid], hemi="left", view="lateral",
            colorbar=True, cmap="hot", title="Left hemisphere",
            axes=ax_left, figure=fig,
        )
        plotting.plot_surf_stat_map(
            surf.infl_right, normed[mid:], hemi="right", view="lateral",
            colorbar=True, cmap="hot", title="Right hemisphere",
            axes=ax_right, figure=fig,
        )
        plt.tight_layout()
        plt.savefig("/tmp/brain_map.png", dpi=130, bbox_inches="tight",
                    facecolor="#111111")
        plt.close()
        return "/tmp/brain_map.png"
    except Exception as e:
        print(f"Brain plot error: {e}")
        return None
|
|
def make_score_chart(scores, overall):
    """Draw a dark-themed horizontal bar chart of region scores to a PNG.

    Args:
        scores: region-name -> 0-100 score mapping (see score_predictions).
        overall: aggregate score displayed in the title.

    Returns:
        Path of the saved chart image ("/tmp/score_chart.png").
    """
    fig, ax = plt.subplots(figsize=(9, 4))
    fig.patch.set_facecolor("#1a1a1a")
    ax.set_facecolor("#1a1a1a")

    names = [region[0] for region in REGIONS]
    colors = [region[3] for region in REGIONS]
    values = [scores.get(name, 0) for name in names]

    bars = ax.barh(names, values, color=colors, height=0.55)
    ax.set_xlim(0, 100)
    # Dashed guide at 70 — the same threshold generate_suggestions() uses.
    ax.axvline(70, color="#888", linestyle="--", linewidth=1, alpha=0.6)
    ax.set_xlabel("Activation score", color="#ccc", fontsize=11)
    ax.set_title(f"Brain region activation | Overall: {overall}/100",
                 color="white", fontsize=13, fontweight="bold", pad=12)
    ax.tick_params(colors="#ccc")
    for spine in ax.spines.values():
        spine.set_edgecolor("#333")

    # Numeric label just to the right of each bar.
    for bar, value in zip(bars, values):
        ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height() / 2,
                f"{value}", va="center", color="white", fontsize=10,
                fontweight="bold")

    plt.tight_layout()
    plt.savefig("/tmp/score_chart.png", dpi=130, bbox_inches="tight",
                facecolor="#1a1a1a")
    plt.close()
    return "/tmp/score_chart.png"
|
|
def generate_suggestions(scores, overall):
    """Turn region scores into actionable writing tips plus a status line.

    Each region scoring below 70 contributes one fixed tip, in a fixed
    order; a region missing from ``scores`` defaults to 100 (no tip).
    A perfect script still gets one encouraging tip.

    Returns:
        A Markdown string: bold overall/status header, blank line, then
        one tip per line.
    """
    rules = [
        ("Prefrontal (attention)",
         "β Open with a bold question or surprising fact to boost attention"),
        ("Emotion (limbic)",
         "β Add emotional language β 'imagine', 'feel', personal stories"),
        ("Temporal (memory)",
         "β Include specific numbers or data points to improve memorability"),
        ("Visual cortex",
         "β Use more visual language β describe what viewers will 'see'"),
        ("Language (Broca's area)",
         "β Break long sentences into shorter, punchier ones"),
        ("Auditory cortex",
         "β Add rhythm and repetition β the brain responds to sound patterns"),
    ]
    tips = [tip for region, tip in rules if scores.get(region, 100) < 70]
    if not tips:
        tips = ["β Excellent! Consider adding a strong call-to-action at the end"]

    if overall >= 75:
        status = "π’ Strong"
    elif overall >= 55:
        status = "π‘ Good, needs polish"
    else:
        status = "π΄ Needs work"

    return f"**Overall: {overall}/100 β {status}**\n\n" + "\n".join(tips)
|
|
| |
| |
| |
def analyze(input_mode, script_text, audio_file, progress=gr.Progress()):
    """Run the full pipeline: input -> audio -> TRIBE events -> scores -> visuals.

    Args:
        input_mode: "Text" (script synthesized to speech via gTTS) or
            "Audio" (an uploaded file is analyzed directly).
        script_text: Script contents; only used in "Text" mode.
        audio_file: Uploaded audio path; only used in "Audio" mode.
        progress: Gradio progress tracker.  The gr.Progress() default is
            the standard Gradio idiom (inspected by the framework), not a
            mutable-default bug.

    Returns:
        Tuple of (brain_map_path, score_chart_path, suggestions_markdown,
        predictions_npy_path); the non-message slots are None on error.
    """
    global _CURRENT_SCRIPT_TEXT

    # Validate input before any expensive model work.
    if input_mode == "Text" and (not script_text or not script_text.strip()):
        return None, None, "β οΈ Please paste your script text first.", None
    if input_mode == "Audio" and audio_file is None:
        return None, None, "β οΈ Please upload an audio file first.", None

    # Lazy first-use load of the module-level model.
    if model is None:
        progress(0.1, desc="Loading TRIBE v2 model (first time ~5 mins)...")
        msg = load_model()
        # load_model() embeds the word "Error" in failure messages.
        if "Error" in msg:
            return None, None, msg, None

    try:
        if input_mode == "Text":
            progress(0.2, desc="Converting text to speech...")

            from gtts import gTTS
            from langdetect import detect

            text = script_text.strip()
            lang = detect(text)  # auto-pick the TTS language from the script
            audio_path = "/tmp/script_audio.mp3"
            tts = gTTS(text=text, lang=lang)
            tts.save(audio_path)

            # Hand the exact script text to the patched transcript
            # function (see _patched_get_transcript_from_audio).
            _CURRENT_SCRIPT_TEXT = text

            progress(0.4, desc="Running TRIBE v2 on generated audio...")
            df = model.get_events_dataframe(audio_path=audio_path)

        else:
            import shutil
            progress(0.2, desc="Loading audio file...")
            ext = os.path.splitext(audio_file)[1] or ".mp3"
            audio_path = f"/tmp/input_audio{ext}"
            shutil.copy(audio_file, audio_path)

            # No known transcript for uploaded audio; the patched
            # transcript function falls back to a placeholder.
            _CURRENT_SCRIPT_TEXT = None

            progress(0.4, desc="Running TRIBE v2 on audio...")
            df = model.get_events_dataframe(audio_path=audio_path)

        progress(0.6, desc="Predicting brain response...")
        preds, segments = model.predict(events=df)

        progress(0.75, desc="Scoring regions...")
        scores, overall = score_predictions(preds)

        progress(0.85, desc="Rendering maps...")
        brain_img = make_brain_plot(preds)  # may be None if nilearn is unavailable
        score_img = make_score_chart(scores, overall)
        suggestions = generate_suggestions(scores, overall)

        # Persist raw predictions so the UI can offer them for download.
        np.save("/tmp/brain_predictions.npy", preds)
        progress(1.0, desc="Done!")
        return brain_img, score_img, suggestions, "/tmp/brain_predictions.npy"

    except Exception as e:
        import traceback
        full_error = traceback.format_exc()
        print(full_error)
        return None, None, f"β Error:\n{str(e)}\n\nFull traceback:\n{full_error}", None
    finally:
        # Always clear the script-text handoff so a stale script can't
        # leak into the next request.
        _CURRENT_SCRIPT_TEXT = None
|
|
| |
| |
| |
# Minimal CSS: center the title and subtitle banners.
css = "#title{text-align:center} #subtitle{text-align:center;color:#888;font-size:14px}"


# Declarative Gradio UI: input column on the left, result images on the
# right; event wiring at the bottom.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo"), css=css) as demo:
    gr.Markdown("# π§ Script Brain Optimizer", elem_id="title")
    gr.Markdown("Analyze your script or audio β real fMRI predictions via **TRIBE v2** β iterate", elem_id="subtitle")

    with gr.Row():
        with gr.Column(scale=1):
            # Mode switch: exactly one of script_input / audio_input is
            # visible at a time (see toggle_mode below).
            input_mode = gr.Radio(
                choices=["Text", "Audio"], value="Text",
                label="Input type",
                info="Text: paste your script | Audio: upload MP3/WAV"
            )
            script_input = gr.Textbox(
                label="Your script",
                placeholder="Paste your content script here...",
                lines=10, max_lines=20, visible=True
            )
            audio_input = gr.Audio(
                label="Upload audio file (MP3, WAV, M4A, FLAC)",
                type="filepath", sources=["upload"], visible=False
            )
            with gr.Row():
                clear_btn = gr.Button("Clear", variant="secondary", scale=1)
                analyze_btn = gr.Button("π§ Analyze", variant="primary", scale=3)
            suggestions_out = gr.Markdown(value="*Add your content and click Analyze...*")
            download_out = gr.File(label="Download predictions (.npy)")

        with gr.Column(scale=2):
            brain_img_out = gr.Image(label="Brain activation map", height=320)
            score_img_out = gr.Image(label="Region scores", height=280)

    def toggle_mode(mode):
        # Show the text box in "Text" mode, the audio uploader otherwise.
        return gr.update(visible=mode=="Text"), gr.update(visible=mode=="Audio")

    input_mode.change(fn=toggle_mode, inputs=[input_mode],
                      outputs=[script_input, audio_input])

    analyze_btn.click(fn=analyze, inputs=[input_mode, script_input, audio_input],
                      outputs=[brain_img_out, score_img_out, suggestions_out, download_out])

    # Reset every input and output back to its initial state.
    clear_btn.click(
        fn=lambda: ("", None, None, None, "*Add your content and click Analyze...*", None),
        outputs=[script_input, audio_input, brain_img_out, score_img_out, suggestions_out, download_out]
    )

    gr.Markdown("---\n*Powered by [TRIBE v2](https://github.com/facebookresearch/tribev2) by Meta FAIR*")
|
|
if __name__ == "__main__":
    # Bind to all interfaces on 7860 (the conventional HF Spaces / Docker port).
    demo.launch(server_name="0.0.0.0", server_port=7860)
|
|