Dddrl commited on
Commit
2a4fd8f
·
verified ·
1 Parent(s): 72e6835

Upload 3 files

Browse files
app.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import librosa
5
+ import numpy as np
6
+ import gradio as gr
7
+ import openai
8
+ import os
9
+
10
+ # Emotion categories
11
+ emotions = ["Neutral", "Happy", "Angry", "Sad", "Surprise"]
12
+
13
+ # CNN model definition
14
+ class CNN(nn.Module):
15
+ def __init__(self, num_classes):
16
+ super(CNN, self).__init__()
17
+ self.name = "CNN"
18
+ self.conv1 = nn.Conv1d(in_channels=768, out_channels=256, kernel_size=3, padding=1)
19
+ self.bn1 = nn.BatchNorm1d(256)
20
+ self.pool = nn.AdaptiveMaxPool1d(output_size=96)
21
+ self.conv2 = nn.Conv1d(in_channels=256, out_channels=128, kernel_size=3, padding=1)
22
+ self.bn2 = nn.BatchNorm1d(128)
23
+ self.conv3 = nn.Conv1d(in_channels=128, out_channels=64, kernel_size=3, padding=1)
24
+ self.bn3 = nn.BatchNorm1d(64)
25
+ self.fc1 = nn.Linear(64 * 96, 128)
26
+ self.dropout = nn.Dropout(0.5)
27
+ self.fc2 = nn.Linear(128, num_classes)
28
+
29
+ def forward(self, x):
30
+ x = x.unsqueeze(1)
31
+ x = x.permute(0, 2, 1)
32
+ x = F.relu(self.bn1(self.conv1(x)))
33
+ x = self.pool(x)
34
+ x = F.relu(self.bn2(self.conv2(x)))
35
+ x = self.pool(x)
36
+ x = F.relu(self.bn3(self.conv3(x)))
37
+ x = self.pool(x)
38
+ x = x.view(x.size(0), -1)
39
+ x = F.relu(self.fc1(x))
40
+ x = self.dropout(x)
41
+ x = self.fc2(x)
42
+ return x
43
+
44
+ # Load the trained model
45
+ model = CNN(num_classes=5)
46
+ model.load_state_dict(torch.load("best_model.pth", map_location="cpu"))
47
+ model.eval()
48
+
49
+ # Extract features from audio file
50
+ def extract_feature(audio_path):
51
+ y, sr = librosa.load(audio_path, sr=16000)
52
+ mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=40)
53
+ max_len = 200
54
+ if mfcc.shape[1] > max_len:
55
+ mfcc = mfcc[:, :max_len]
56
+ else:
57
+ pad_width = max_len - mfcc.shape[1]
58
+ mfcc = np.pad(mfcc, ((0, 0), (0, pad_width)), mode='constant')
59
+ feature = np.tile(mfcc, (int(768 / 40), 1))
60
+ feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0)
61
+ return feature
62
+
63
+ # Full pipeline: emotion detection + GPT response
64
+ def predict_and_reply(audio_path):
65
+ feature = extract_feature(audio_path)
66
+ with torch.no_grad():
67
+ output = model(feature)
68
+ pred = torch.argmax(output, dim=1).item()
69
+ emotion = emotions[pred]
70
+
71
+ prompt = f"The user sounds {emotion.lower()}. What would you like to say to them?"
72
+
73
+ try:
74
+ openai.api_key = os.getenv("OPENAI_API_KEY", "your-openai-api-key") # Replace with real key or env var
75
+ response = openai.ChatCompletion.create(
76
+ model="gpt-3.5-turbo",
77
+ messages=[
78
+ {"role": "system", "content": "You are an empathetic AI assistant."},
79
+ {"role": "user", "content": prompt}
80
+ ]
81
+ )
82
+ reply = response['choices'][0]['message']['content']
83
+ except Exception as e:
84
+ reply = f"❌ GPT Error: {str(e)}"
85
+
86
+ return f"🎧 Detected Emotion: **{emotion}**\n\n💬 GPT Says:\n{reply}"
87
+
88
+ #️ Gradio app layout
89
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
90
+ gr.Markdown("## 🎙️ 情绪检测 + 聊天机器人")
91
+ gr.Markdown("上传或录制一段简短的语音片段,我会识别你的情绪,并请求 GPT 做出共情的回应。")
92
+
93
+ with gr.Row():
94
+ with gr.Column():
95
+ audio_input = gr.Audio(label="🎧 语音输入", type="filepath", format="wav")
96
+ submit_btn = gr.Button("🚀 提交")
97
+ with gr.Column():
98
+ output_text = gr.Markdown(label="💬 GPT 回复")
99
+
100
+ submit_btn.click(fn=predict_and_reply, inputs=audio_input, outputs=output_text)
101
+
102
+ demo.launch()
best_model_CNN_bs32_lr0.0005_epoch9_acc0.9238.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c1ff55da3574c0126d1ed0970a0e0584d4a3f0aa9e668562e95043b33bdbf946
3
+ size 6017753
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchaudio
3
+ torchvision
4
+ transformers
5
+ librosa
6
+ matplotlib
7
+ numpy
8
+ openai
9
+ pandas
10
+ tqdm
11
+ scikit-learn
12
+ gradio