import gradio as gr
import torch
import time
import numpy as np
import mediapipe as mp
from PIL import Image
import cv2
# from pytorch_grad_cam.utils.image import show_cam_on_image
import scipy.io.wavfile as wav
# Importing necessary components for the Gradio app
from model import pth_model_static, pth_model_dynamic, cam, pth_processing
from face_utils import get_box, display_info
from config import DICT_EMO, config_data
from plot import statistics_plot
from moviepy.editor import AudioFileClip
import soundfile as sf
import torchaudio
from speechbrain.pretrained.interfaces import foreign_class
from paraformer import AudioReader, CttPunctuator, FSMNVad, ParaformerOffline
from gradio_client import Client
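##############################################################################################
# Utilities for the multimodal (text / speech / video) scoring demo backed by
# the Liusuthu/TextDepression Space. Each public entry point returns a
# (text, score) pair on the same scale: text_score maps positive sentiment to
# a negative score and negative sentiment to a positive one; speech_score
# fuses the transcript's text score with a wav2vec2 audio-emotion score;
# video_score adds a facial-expression score and reuses speech_score on the
# video's audio track.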
##############################################################################################
client = Client("Liusuthu/TextDepression")
mp_face_mesh = mp.solutions.face_mesh
classifier = foreign_class(
    source="pretrained_models/local-speechbrain/emotion-recognition-wav2vec2-IEMOCAP",
    pymodule_file="custom_interface.py",
    classname="CustomEncoderWav2vec2Classifier",
    savedir="pretrained_models/local-speechbrain/emotion-recognition-wav2vec2-IEMOCAP",
)
ASR_model = ParaformerOffline()  # offline Paraformer speech recogniser
vad = FSMNVad()  # FSMN voice-activity detector, used to segment recordings
punc = CttPunctuator()  # CT-Transformer punctuator, restores punctuation in transcripts
#########################################################################################
def text_api(text: str):
    result = client.predict(
        text,  # str for the '输入文字' (input text) Textbox component
        api_name="/predict",
    )
    return result
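
# text_api returns the remote Space's output serialised as a string; the
# slicing in text_score below assumes fixed offsets — a two-character label
# ("正向" = positive, "负向" = negative) starting 4 characters after the
# substring "text", and a float starting 3 characters after "probability" and
# ending 4 characters before the end of the string. A hypothetical example of
# such a response (the exact format is defined by the remote Space):
#   {"text": "正向", "probability": 0.91...}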
#######################################################################
# Scoring functions: plain value-in / value-out, no UI state.
def text_score(text):
    if text is None:
        gr.Warning("提交内容为空!")  # "Submission is empty!"
        return None, None
    string = text_api(text)
    # Slice the sentiment label and its probability out of the response string.
    part1 = string.partition("text")
    want1 = part1[2]
    label = want1[4:6]
    part2 = string.partition("probability")
    want2 = part2[2]
    prob = float(want2[3:-4])
    # Signed score: positive sentiment ("正向") pulls the score negative,
    # negative sentiment pushes it positive.
    if label == "正向":
        score = -np.log10(prob * 10)
    else:
        score = np.log10(prob * 10)
    return text, score
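
# Worked example for text_score: a "正向" (positive) label with prob = 0.9
# gives score = -log10(0.9 * 10) = -log10(9) ≈ -0.954; the same probability on
# a negative label gives ≈ +0.954. Probabilities below 0.1 would flip the sign
# of the magnitude, so the upstream model is assumed to report prob > 0.1.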
def speech_score(audio):
    if audio is None:
        gr.Warning("提交内容为空!请等待音频加载完毕后再尝试提交!")  # "Submission is empty! Wait for the audio to finish loading before submitting."
        return None, None
    sample_rate, signal = audio  # (rate, samples) tuple from the Gradio audio component
    # Normalise to [-1, 1], then round-trip through disk to get a 16 kHz
    # 16-bit mono WAV that both the ASR stack and the classifier accept.
    signal = signal.astype(np.float32)
    signal /= np.max(np.abs(signal))
    sf.write("data/a.wav", signal, sample_rate)
    signal, sample_rate = torchaudio.load("data/a.wav")
    signal1 = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(signal)
    torchaudio.save("data/out.wav", signal1, 16000, encoding="PCM_S", bits_per_sample=16)
    Audio = "data/out.wav"
    speech, sample_rate = AudioReader.read_wav_file(Audio)
    # Transcription: the VAD segments the recording, Paraformer recognises
    # each segment, and the punctuator restores punctuation.
    segments = vad.segments_offline(speech)
    text_results = ""
    for part in segments:
        _result = ASR_model.infer_offline(
            speech[part[0] * 16 : part[1] * 16],
            hot_words="任意热词 空格分开",  # placeholder: "any hot words, space-separated"
        )
        text_results += punc.punctuate(_result)[0]
    # Audio emotion recognition (wav2vec2 fine-tuned on IEMOCAP).
    out_prob, score, index, text_lab = classifier.classify_batch(signal1)
    prob = out_prob.squeeze(0).numpy()
    # Collapse the first two class probabilities into a signed, log-compressed
    # audio score.
    score2 = 10 * prob[0] - 10 * prob[1]
    if score2 >= 0:
        score2 = np.log10(score2)
    else:
        score2 = -np.log10(-score2)
    # Score the transcript with the text model, then fuse the two modalities
    # with a 2:1 weighting in favour of the text.
    text, score1 = text_score(text_results)
    score = (2 / 3) * score1 + (1 / 3) * score2
    return text, score
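
# Worked example for the fusion in speech_score: with out_prob collapsing to
# prob = [0.8, 0.1, ...], the raw audio score is 10*0.8 - 10*0.1 = 7,
# log-compressed to log10(7) ≈ 0.845; the returned score then weights the
# transcript twice as heavily: (2/3)*score1 + (1/3)*0.845.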
def video_score(video):
    if video is None:
        gr.Warning("提交内容为空!请等待视频加载完毕后再尝试提交!")  # "Submission is empty! Wait for the video to finish loading before submitting."
        return None, None
    else:
cap = cv2.VideoCapture(video)
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = np.round(cap.get(cv2.CAP_PROP_FPS))
path_save_video_face = 'result_face.mp4'
vid_writer_face = cv2.VideoWriter(path_save_video_face, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))
# path_save_video_hm = 'result_hm.mp4'
# vid_writer_hm = cv2.VideoWriter(path_save_video_hm, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))
lstm_features = []
count_frame = 1
count_face = 0
probs = []
frames = []
last_output = None
last_heatmap = None
cur_face = None
with mp_face_mesh.FaceMesh(
max_num_faces=1,
refine_landmarks=False,
min_detection_confidence=0.5,
min_tracking_confidence=0.5) as face_mesh:
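            # Per frame: detect a face with MediaPipe FaceMesh, crop it, and
            # every FRAME_DOWNSAMPLING-th face extract static CNN features;
            # keep a sliding window of the last 10 feature vectors and feed the
            # stack to the dynamic (temporal) model for a 7-class emotion
            # probability vector. Frames without a face append NaN rows.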
while cap.isOpened():
_, frame = cap.read()
if frame is None: break
frame_copy = frame.copy()
frame_copy.flags.writeable = False
frame_copy = cv2.cvtColor(frame_copy, cv2.COLOR_BGR2RGB)
results = face_mesh.process(frame_copy)
frame_copy.flags.writeable = True
if results.multi_face_landmarks:
for fl in results.multi_face_landmarks:
startX, startY, endX, endY = get_box(fl, w, h)
cur_face = frame_copy[startY:endY, startX: endX]
if count_face%config_data.FRAME_DOWNSAMPLING == 0:
cur_face_copy = pth_processing(Image.fromarray(cur_face))
with torch.no_grad():
features = torch.nn.functional.relu(pth_model_static.extract_features(cur_face_copy)).detach().numpy()
# grayscale_cam = cam(input_tensor=cur_face_copy)
# grayscale_cam = grayscale_cam[0, :]
# cur_face_hm = cv2.resize(cur_face,(224,224), interpolation = cv2.INTER_AREA)
# cur_face_hm = np.float32(cur_face_hm) / 255
# heatmap = show_cam_on_image(cur_face_hm, grayscale_cam, use_rgb=False)
# last_heatmap = heatmap
if len(lstm_features) == 0:
lstm_features = [features]*10
else:
lstm_features = lstm_features[1:] + [features]
lstm_f = torch.from_numpy(np.vstack(lstm_features))
lstm_f = torch.unsqueeze(lstm_f, 0)
with torch.no_grad():
output = pth_model_dynamic(lstm_f).detach().numpy()
last_output = output
if count_face == 0:
count_face += 1
else:
if last_output is not None:
output = last_output
# heatmap = last_heatmap
elif last_output is None:
output = np.empty((1, 7))
output[:] = np.nan
probs.append(output[0])
frames.append(count_frame)
else:
if last_output is not None:
lstm_features = []
empty = np.empty((7))
empty[:] = np.nan
probs.append(empty)
frames.append(count_frame)
if cur_face is not None:
# heatmap_f = display_info(heatmap, 'Frame: {}'.format(count_frame), box_scale=.3)
cur_face = cv2.cvtColor(cur_face, cv2.COLOR_RGB2BGR)
cur_face = cv2.resize(cur_face, (224,224), interpolation = cv2.INTER_AREA)
cur_face = display_info(cur_face, 'Frame: {}'.format(count_frame), box_scale=.3)
vid_writer_face.write(cur_face)
# vid_writer_hm.write(heatmap_f)
count_frame += 1
if count_face != 0:
count_face += 1
vid_writer_face.release()
# vid_writer_hm.release()
stat = statistics_plot(frames, probs)
if not stat:
return None, None
        # Average the first three emotion probabilities over the frames in
        # which a face was detected (NaN rows mark frames without a face).
        s1 = 0
        s2 = 0
        s3 = 0
        frames_len = len(frames)
        for i in range(len(frames)):
            if np.isnan(probs[i][0]):
                frames_len = frames_len - 1
            else:
                s1 = s1 + probs[i][0]
                s2 = s2 + probs[i][1]
                s3 = s3 + probs[i][2]
        if frames_len == 0:
            # No face was detected in any frame, so no facial score is available.
            return None, None
        s1 = s1 / frames_len
        s2 = s2 / frames_len
        s3 = s3 / frames_len
        print("from func:video_score————mean probs:", [s1, s2, s3])
        # Facial score: neutral if the first (assumed neutral) class dominates;
        # otherwise the sign follows whichever of the next two (assumed
        # positive/negative valence, per DICT_EMO's ordering) is stronger.
        if s1 >= 0.4:
            score1 = 0
        elif s2 >= s3:
            score1 = -1
        else:
            score1 = 1
        # Extract the audio track as mono WAV and score it with speech_score.
        my_audio_clip = AudioFileClip(video)
        my_audio_clip.write_audiofile("data/audio.wav", ffmpeg_params=["-ac", "1"])
        audio = wav.read("data/audio.wav")
        text, score2 = speech_score(audio)
        # Fuse: the speech/text score dominates, the face contributes 1/7.
        score = (score1 + 6 * score2) / 7
        return text, score
#######################################################################
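
# Minimal usage sketch (an assumption, not the original Space's wiring — the
# real app builds its own UI elsewhere). Component labels are hypothetical.
if __name__ == "__main__":
    demo = gr.Interface(
        fn=text_score,
        inputs=gr.Textbox(label="输入文字"),  # hypothetical label matching the remote Space
        outputs=[gr.Textbox(label="text"), gr.Number(label="score")],
    )
    demo.launch()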