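"""Mental Health Chat Analyzer (Gradio Space).

Collects conversation text from raw input, chat screenshots (OCR), and video
frames, splits it by speaker, scores per-speaker emotions with a transformer
classifier, and renders risk and emotion charts. An optional privacy mode
masks participant names as User_1, User_2, ...
"""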
import gradio as gr
from transformers import LongformerTokenizer, pipeline
from PIL import Image
import pytesseract
import cv2
import re
import torch
import matplotlib.pyplot as plt
import math
from typing import Dict, List, Any
import numpy as np
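# Model setup: the Longformer tokenizer is used only to budget chunk sizes in
# chunk_text(); classification itself is done by the DistilRoBERTa emotion
# pipeline below. Runs on GPU when available (device=0), otherwise CPU (-1).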
device = 0 if torch.cuda.is_available() else -1
model_id = "allenai/longformer-base-4096"
tok = LongformerTokenizer.from_pretrained(model_id)
emo_head = pipeline(
    "text-classification",
    model="j-hartmann/emotion-english-distilroberta-base",
    return_all_scores=True,
    device=device,
)
time_regex = re.compile(r"(\d{1,2}[:]\d{2}\s*(AM|PM|am|pm)?)|(\d{1,2}[/]\d{1,2}[/]\d{2,4})")
negative_keys = {"anger", "sadness", "fear", "disgust"}
positive_keys = {"joy", "surprise"}
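# Replace real participant names with stable pseudonyms (User_1, User_2, ...).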
def mask_names(names: List[str]) -> Dict[str, str]:
    return {n: f"User_{i+1}" for i, n in enumerate(names)}
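# OCR a single chat screenshot with Tesseract.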
def ocr_image_path(path: str) -> str:
    img = Image.open(path).convert("RGB")
    return pytesseract.image_to_string(img)
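# OCR an uploaded video by sampling one frame out of every 25 and keeping any
# frame whose extracted text is non-empty.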
def ocr_video_path(path: str) -> str:
    cap = cv2.VideoCapture(path)
    texts = []
    idx = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if idx % 25 == 0:
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            img = Image.fromarray(rgb)
            t = pytesseract.image_to_string(img)
            if t.strip():
                texts.append(t)
        idx += 1
    cap.release()
    return "\n".join(texts)
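# Split combined text into per-speaker transcripts using the "Name: message"
# convention; text with no speaker prefix falls back to a single "User"
# bucket. When privacy masking is on, speaker names are pseudonymized.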
def split_by_speaker(text: str, privacy: bool) -> Dict[str, str]:
    speakers: Dict[str, List[str]] = {}
    for raw in text.splitlines():
        if ":" in raw:
            name, msg = raw.split(":", 1)
            name, msg = name.strip(), msg.strip()
            if msg:
                speakers.setdefault(name, []).append(msg)
    if not speakers:
        speakers["User"] = [text]
    if privacy:
        mapping = mask_names(list(speakers.keys()))
        return {mapping[k]: " ".join(v) for k, v in speakers.items()}
    return {k: " ".join(v) for k, v in speakers.items()}
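# Greedily pack words into chunks of at most max_tokens Longformer tokens.
# Note: this re-tokenizes the growing chunk on every word, which is simple
# but O(n^2) in the number of words for long transcripts.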
def chunk_text(text: str, max_tokens: int = 2048) -> List[str]:
    words = text.split()
    chunks: List[str] = []
    temp: List[str] = []
    for w in words:
        temp.append(w)
        enc = tok(" ".join(temp), truncation=True, max_length=max_tokens)
        if len(enc["input_ids"]) >= max_tokens:
            temp.pop()
            chunks.append(" ".join(temp))
            temp = [w]
    if temp:
        chunks.append(" ".join(temp))
    return chunks
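# Score one chunk of text with the emotion pipeline and return {label: score}.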
def emotion_scores(text: str) -> Dict[str, float]:
    # Truncate at the classifier's limit: the DistilRoBERTa emotion model
    # accepts at most 512 tokens, while chunks are sized against Longformer's
    # much larger budget, so over-long chunks would otherwise raise an error.
    res = emo_head(text, truncation=True)[0]
    return {x["label"]: float(x["score"]) for x in res}
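# Average the emotion scores across all chunks for one speaker.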
def emotions_over_chunks(chunks: List[str]) -> Dict[str, float]:
    if not chunks:
        return {}
    sums: Dict[str, float] = {}
    count = 0
    for c in chunks:
        e = emotion_scores(c)
        for k, v in e.items():
            sums[k] = sums.get(k, 0.0) + v
        count += 1
    return {k: v / count for k, v in sums.items()} if count else {}
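# Heuristic risk score: 70% weight on the summed negative emotions (anger,
# sadness, fear, disgust) plus 30% on the single strongest one, clamped to
# [0, 1].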
def compute_risk(emotions: Dict[str, float]) -> float:
    neg = sum(emotions.get(k, 0.0) for k in negative_keys)
    strongest_neg = max((emotions.get(k, 0.0) for k in negative_keys), default=0.0)
    risk = 0.7 * neg + 0.3 * strongest_neg
    return max(0.0, min(1.0, risk))
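# End-to-end pipeline behind the "Analyze Conversation" button: gather text
# from all inputs, split it by speaker, score emotions and risk per speaker,
# then build two matplotlib figures (risk/group overview and per-speaker
# emotion profiles) for the Gradio Plot outputs.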
def analyze(text_input, image_paths, video_paths, privacy_choice):
    collected: List[str] = []
    if text_input and text_input.strip():
        collected.append(text_input)
    if image_paths:
        for p in image_paths:
            t = ocr_image_path(p)
            if t.strip():
                collected.append(t)
    if video_paths:
        for p in video_paths:
            t = ocr_video_path(p)
            if t.strip():
                collected.append(t)
    if not collected:
        return None, None
    combined = "\n".join(collected)
    speakers = split_by_speaker(combined, privacy_choice == "ON")
    results: List[Dict[str, Any]] = []
    for name, txt in speakers.items():
        chunks = chunk_text(txt)
        emos = emotions_over_chunks(chunks)
        risk = compute_risk(emos)
        results.append(
            {
                "name": name,
                "risk": risk,
                "emotions": emos,
            }
        )
    plt.style.use("default")
    fig, ax = plt.subplots(1, 2, figsize=(11, 4))
    fig.patch.set_facecolor("white")
    names = [x["name"] for x in results]
    scores = [x["risk"] for x in results]
    ax[0].bar(names, scores, color="#DC2626", alpha=0.8)
    ax[0].set_ylim(0, 1)
    ax[0].set_title("Risk Levels", fontweight="bold", fontsize=12, color="#1F2937")
    ax[0].set_ylabel("Risk Score", fontsize=10, color="#4B5563")
    ax[0].set_facecolor("white")
    ax[0].grid(axis="y", alpha=0.2, linestyle="--")
    ax[0].spines["top"].set_visible(False)
    ax[0].spines["right"].set_visible(False)
    group_emo: Dict[str, float] = {}
    for r in results:
        for k, v in r["emotions"].items():
            group_emo[k] = group_emo.get(k, 0.0) + v
    group_emo = {k: v / len(results) for k, v in group_emo.items()}
    colors = ["#10B981", "#3B82F6", "#8B5CF6", "#F59E0B", "#EC4899", "#06B6D4"]
    ax[1].bar(list(group_emo.keys()), list(group_emo.values()), color=colors[: len(group_emo)], alpha=0.8)
    ax[1].set_ylim(0, 1)
    ax[1].set_title("Group Emotion", fontweight="bold", fontsize=12, color="#1F2937")
    ax[1].set_ylabel("Intensity", fontsize=10, color="#4B5563")
    ax[1].set_facecolor("white")
    ax[1].grid(axis="y", alpha=0.2, linestyle="--")
    ax[1].spines["top"].set_visible(False)
    ax[1].spines["right"].set_visible(False)
    ax[1].tick_params(axis="x", rotation=45)
    plt.tight_layout()
    n = len(results)
    cols = min(3, n)
    rows = math.ceil(n / cols)
    fig2, ax2 = plt.subplots(rows, cols, figsize=(5 * cols, 3 * rows))
    fig2.patch.set_facecolor("white")
    axlist = [ax2] if n == 1 else ax2.flatten()
    emotion_colors = {
        "anger": "#EF4444",
        "sadness": "#3B82F6",
        "fear": "#8B5CF6",
        "disgust": "#F59E0B",
        "joy": "#10B981",
        "surprise": "#EC4899",
    }
    for i, r in enumerate(results):
        axp = axlist[i]
        emotions = list(r["emotions"].keys())
        values = list(r["emotions"].values())
        bar_colors = [emotion_colors.get(e, "#6B7280") for e in emotions]
        axp.bar(emotions, values, color=bar_colors, alpha=0.8)
        axp.set_ylim(0, 1)
        axp.set_title(r["name"], fontweight="bold", fontsize=11, color="#1F2937")
        axp.set_ylabel("Intensity", fontsize=9, color="#4B5563")
        axp.set_facecolor("white")
        axp.grid(axis="y", alpha=0.2, linestyle="--")
        axp.spines["top"].set_visible(False)
        axp.spines["right"].set_visible(False)
        axp.tick_params(axis="x", rotation=45, labelsize=9)
    for j in range(len(axlist) - n):
        axlist[n + j].axis("off")
    fig2.tight_layout()
    return fig, fig2
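# Custom CSS: force a white, Inter-font light theme over the Gradio defaults
# and style buttons, inputs, tabs, file drops, and radio options.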
custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700;800&display=swap');

* {
    font-family: 'Inter', sans-serif !important;
}

body {
    background: white !important;
}

.gradio-container {
    max-width: 1400px !important;
    margin: 0 auto !important;
    background: white !important;
}

.main {
    background: white !important;
}

.contain {
    background: white !important;
}

.gr-button-primary {
    background: linear-gradient(135deg, #667EEA 0%, #764BA2 100%) !important;
    box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
    color: white !important;
    border: none !important;
}

.gr-button-primary:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 6px 20px rgba(102, 126, 234, 0.5) !important;
}

.gr-box, .gr-form, .gr-panel {
    background: white !important;
    border: 1px solid #E5E7EB !important;
    border-radius: 12px !important;
    box-shadow: 0 1px 3px rgba(0,0,0,0.05) !important;
}

.gr-input, .gr-textarea {
    background: white !important;
    border: 1px solid #D1D5DB !important;
    border-radius: 8px !important;
    color: #1F2937 !important;
}

.gr-input:focus, .gr-textarea:focus {
    border-color: #667EEA !important;
    box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1) !important;
}

.gr-input::placeholder, .gr-textarea::placeholder {
    color: #9CA3AF !important;
}

label {
    color: #374151 !important;
    font-weight: 600 !important;
    font-size: 14px !important;
}

.tabs {
    background: white !important;
    border: 1px solid #E5E7EB !important;
    border-radius: 12px !important;
}

.tab-nav {
    background: #F9FAFB !important;
    border-bottom: 1px solid #E5E7EB !important;
    padding: 8px !important;
}

.tab-nav button {
    color: #6B7280 !important;
    font-weight: 600 !important;
    background: transparent !important;
    border-radius: 8px !important;
    padding: 10px 20px !important;
}

.tab-nav button.selected {
    background: white !important;
    color: #667EEA !important;
    border-bottom: 2px solid #667EEA !important;
    box-shadow: 0 1px 3px rgba(0,0,0,0.05) !important;
}

.gr-accordion {
    background: white !important;
    border: 1px solid #E5E7EB !important;
    border-radius: 10px !important;
}

.gr-file {
    background: white !important;
    border: 2px dashed #D1D5DB !important;
    border-radius: 10px !important;
}

.gr-file:hover {
    border-color: #667EEA !important;
}

.gr-radio {
    background: white !important;
}

.gr-radio label {
    background: white !important;
    border: 1px solid #D1D5DB !important;
    border-radius: 8px !important;
    padding: 10px 16px !important;
    color: #4B5563 !important;
}

.gr-radio label.selected {
    background: #EEF2FF !important;
    border-color: #667EEA !important;
    color: #667EEA !important;
}

.gr-plot {
    background: white !important;
    border-radius: 12px !important;
    padding: 16px !important;
}

footer {
    display: none !important;
}
"""
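# Gradio UI: hero header and feature cards, privacy toggle, conversation
# textbox, optional screenshot/video uploads, and two result tabs wired to
# analyze().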
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="Mental Health Chat Analyzer") as demo:
    gr.HTML(
        """
        <div style='text-align: center; padding: 80px 20px 60px; background: white; border-bottom: 2px solid #E5E7EB; margin-bottom: 40px;'>
            <h1 style='color: #1F2937; font-size: 48px; font-weight: 800; margin: 0 0 20px 0;'>
                Mental Health Chat Analyzer
            </h1>
            <p style='color: #6B7280; font-size: 20px; max-width: 700px; margin: 0 auto 30px; line-height: 1.6;'>
                AI-powered emotional intelligence that analyzes conversations to provide emotion and risk insights
            </p>
        </div>
        <div style='max-width: 1000px; margin: 0 auto 60px; padding: 0 20px; background: white;'>
            <div style='display: grid; grid-template-columns: repeat(auto-fit, minmax(250px, 1fr)); gap: 30px;'>
                <div style='text-align: center; padding: 30px; background: white; border-radius: 12px; box-shadow: 0 2px 8px rgba(0,0,0,0.06); border: 1px solid #E5E7EB;'>
                    <div style='font-size: 40px; margin-bottom: 15px;'>🧠</div>
                    <h3 style='color: #1F2937; margin-bottom: 10px; font-weight: 700; font-size: 18px;'>AI Analysis</h3>
                    <p style='color: #6B7280; font-size: 15px;'>Emotion and risk detection</p>
                </div>
                <div style='text-align: center; padding: 30px; background: white; border-radius: 12px; box-shadow: 0 2px 8px rgba(0,0,0,0.06); border: 1px solid #E5E7EB;'>
                    <div style='font-size: 40px; margin-bottom: 15px;'>📊</div>
                    <h3 style='color: #1F2937; margin-bottom: 10px; font-weight: 700; font-size: 18px;'>Visual Insights</h3>
                    <p style='color: #6B7280; font-size: 15px;'>Risk and emotion charts</p>
                </div>
                <div style='text-align: center; padding: 30px; background: white; border-radius: 12px; box-shadow: 0 2px 8px rgba(0,0,0,0.06); border: 1px solid #E5E7EB;'>
                    <div style='font-size: 40px; margin-bottom: 15px;'>🔒</div>
                    <h3 style='color: #1F2937; margin-bottom: 10px; font-weight: 700; font-size: 18px;'>Privacy First</h3>
                    <p style='color: #6B7280; font-size: 15px;'>Your data stays on-device</p>
                </div>
            </div>
        </div>
        """
    )
    gr.HTML(
        "<h2 style='text-align: center; font-size: 32px; margin-bottom: 30px; color: #1F2937; background: white; font-weight: 700;'>Start Your Analysis</h2>"
    )
    with gr.Row():
        with gr.Column():
            privacy = gr.Radio(
                choices=["OFF", "ON"],
                value="OFF",
                label="Privacy Masking",
                info="Enable to anonymize participant names",
            )
            text_in = gr.Textbox(
                label="Conversation Text",
                placeholder="Format: Name: message\n\nJohn: I'm stressed about work\nMary: Let's talk about it",
                lines=10,
            )
            with gr.Accordion("Upload Files (Optional)", open=False):
                img_in = gr.File(
                    label="Screenshots",
                    file_types=["image"],
                    file_count="multiple",
                    type="filepath",
                )
                vid_in = gr.File(
                    label="Videos",
                    file_count="multiple",
                    type="filepath",
                )
            analyze_btn = gr.Button("Analyze Conversation", variant="primary", size="lg")
    with gr.Tabs():
        with gr.Tab("Risk Assessment"):
            plot1 = gr.Plot()
        with gr.Tab("Individual Profiles"):
            plot2 = gr.Plot()
    analyze_btn.click(
        analyze,
        inputs=[text_in, img_in, vid_in, privacy],
        outputs=[plot1, plot2],
    )

if __name__ == "__main__":
    demo.launch()