import logging

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import torch
import whisper

from model import Mamba

logging.basicConfig(level=logging.INFO)

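# Note: the Mamba classifier imported from model.py is assumed to expose a
# scikit-learn-style predict_proba(list_of_texts) method that returns one
# probability per emotion class; that interface lives in model.py, not here.

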
def plotly_plot_text(text):
    data = pd.DataFrame()
    data['Emotion'] = ['😠 anger', '🤢 disgust', '😨 fear', '😊 joy/happiness', '😐 neutral', '😢 sadness', '😲 surprise/enthusiasm']
    # Only query the classifier for non-empty input; otherwise show all-zero probabilities.
    data['Probability'] = model.predict_proba([text])[0].tolist() if text.strip() else [0.0] * data.shape[0]
    p = px.bar(data, x='Emotion', y='Probability', color="Probability")
    # The text tab binds exactly two outputs: the bar chart and the dominant-emotion markdown.
    return (
        p,
        f"## ✍️ Dominant Emotion: {data['Emotion'].values[np.argmax(np.array(data['Probability']))]}"
    )


def transcribe_audio(audio_path):
    # Loads the Whisper "base" checkpoint on every call; see the caching sketch below.
    whisper_model = whisper.load_model("base")
    try:
        # fp16=False forces FP32 inference, which also works on CPU-only hosts.
        result = whisper_model.transcribe(audio_path, fp16=False)
        return result.get('text', '')
    except Exception as e:
        logging.error(f"Transcription failed: {e}")
        return ""
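

# A minimal caching sketch (an assumption, not part of the original app): loading
# Whisper once at module level would avoid re-reading the checkpoint on every request.
#
#   _whisper_model = None
#
#   def get_whisper_model():
#       global _whisper_model
#       if _whisper_model is None:
#           _whisper_model = whisper.load_model("base")
#       return _whisper_model

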
def plotly_plot_audio(audio_path):
    data = pd.DataFrame()
    data['Emotion'] = ['😠 anger', '🤢 disgust', '😨 fear', '😊 joy/happiness', '😐 neutral', '😢 sadness', '😲 surprise/enthusiasm']
    try:
        # audio_path is None when the recording is cleared; skip transcription in that case.
        text = transcribe_audio(audio_path) if audio_path else ""
        data['Probability'] = model.predict_proba([text])[0].tolist() if text.strip() else [0.0] * data.shape[0]
        p = px.bar(data, x='Emotion', y='Probability', color="Probability")
        # The audio tab binds three outputs: bar chart, transcription box, dominant-emotion markdown.
        return (
            p,
            f"🎤 Transcription:\n{text}",
            f"## ✍️ Dominant Emotion: {data['Emotion'].values[np.argmax(np.array(data['Probability']))]}"
        )
    except Exception as e:
        logging.error(f"Processing failed: {e}")
        data['Probability'] = [0.0] * data.shape[0]
        p = px.bar(data, x='Emotion', y='Probability', color="Probability")
        return (
            p,
            "❌ Error processing audio",
            "⚠️ Processing Error"
        )


def create_demo_text():
    with gr.Blocks(theme='Nymbo/rounded-gradient', css=".gradio-container {background-color: #F0F8FF}", title="Emotion Detection") as demo:
        gr.Markdown("# Text-based bilingual emotion recognition")

        with gr.Row():
            text_input = gr.Textbox(label="Write Text")

        with gr.Row():
            top_emotion = gr.Markdown("## ✍️ Dominant Emotion: Waiting for input ...",
                                      elem_classes="dominant-emotion")

        with gr.Row():
            text_plot = gr.Plot(label="Text Analysis")

        text_input.change(fn=plotly_plot_text, inputs=text_input, outputs=[text_plot, top_emotion])
    return demo


def create_demo_audio():
    with gr.Blocks(theme='Nymbo/rounded-gradient', css=".gradio-container {background-color: #F0F8FF}", title="Emotion Detection") as demo:
        gr.Markdown("# Text-based bilingual emotion recognition with audio transcription")

        with gr.Row():
            audio_input = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Record or Upload Audio",
                format="wav",
                interactive=True
            )

        with gr.Row():
            top_emotion = gr.Markdown("## ✍️ Dominant Emotion: Waiting for input ...",
                                      elem_classes="dominant-emotion")

        with gr.Row():
            text_plot = gr.Plot(label="Text Analysis")

        transcription = gr.Textbox(
            label="📝 Transcription Results",
            placeholder="Transcribed text will appear here...",
            lines=3,
            max_lines=6
        )

        audio_input.change(fn=plotly_plot_audio, inputs=audio_input, outputs=[text_plot, transcription, top_emotion])
    return demo


def create_demo():
    text = create_demo_text()
    audio = create_demo_audio()
    demo = gr.TabbedInterface(
        [text, audio],
        ["Text Prediction", "Transcribed Audio Prediction"],
    )
    return demo


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Rebuild the classifier with its training-time hyperparameters, then restore
    # the fine-tuned weights from the checkpoint and switch to inference mode.
    model = Mamba(num_layers=2, d_input=1024, d_model=512, num_classes=7, model_name='jina', pooling=None).to(device)
    checkpoint = torch.load("Mamba_jina_checkpoint.pth", map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    demo = create_demo()
    demo.launch()
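
    # If the app needs to be reachable from other machines or via a temporary public
    # link, launch can instead be called with standard Gradio options, e.g.:
    # demo.launch(share=True, server_name="0.0.0.0", server_port=7860)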