jujutech committed
Commit 59f0d90 · verified · 1 Parent(s): 80ee7cf

Update app.py

Files changed (1): app.py (+24 -8)
app.py CHANGED

@@ -4,8 +4,8 @@ import torch
 import librosa
 
 # Load the model and processor
-processor = Wav2Vec2Processor.from_pretrained("SpeechResearch/whisper-ft-normal")
-model = Wav2Vec2ForCTC.from_pretrained("SpeechResearch/whisper-ft-normal")
+processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h")
+model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h")
 
 def transcribe_speech(audio_path):
     speech, _ = librosa.load(audio_path, sr=16000)
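This hunk swaps the ASR checkpoint from SpeechResearch/whisper-ft-normal to facebook/wav2vec2-large-960h. The old repo name suggests a Whisper fine-tune, which the Wav2Vec2Processor / Wav2Vec2ForCTC classes cannot load, so the swap plausibly fixes a load-time failure. The hunk also skips the middle of transcribe_speech (app.py lines 12-15 are not shown); a standard Wav2Vec2 greedy-CTC body consistent with the visible lines would look roughly like this sketch, not necessarily the file's actual code:

    import torch
    import librosa
    from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h")
    model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h")

    def transcribe_speech(audio_path):
        speech, _ = librosa.load(audio_path, sr=16000)  # model expects 16 kHz mono
        inputs = processor(speech, sampling_rate=16000, return_tensors="pt")
        with torch.no_grad():
            logits = model(inputs.input_values).logits  # (batch, frames, vocab)
        predicted_ids = torch.argmax(logits, dim=-1)    # greedy CTC decoding
        transcription = processor.batch_decode(predicted_ids)
        return transcription[0]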
@@ -16,12 +16,28 @@ def transcribe_speech(audio_path):
     transcription = processor.batch_decode(predicted_ids)
     return transcription[0]
 
+def get_dreamtalk(image_in, speech):
+    try:
+        client = Client("https://fffiloni-dreamtalk.hf.space/")
+        result = client.predict(
+            speech,  # filepath in 'Audio input' Audio component
+            image_in,  # filepath in 'Image' Image component
+            "M030_front_neutral_level1_001.mat",  # Literal in 'emotional style' Dropdown component
+            api_name="/infer"
+        )
+        return result['video']
+    except Exception as e:
+        print(f"Error in get_dreamtalk: {e}")
+        raise gr.Error(f"Error in get_dreamtalk: {str(e)}")
+
 def pipe(text, voice, image_in):
-    # Assuming voice is a file path to the audio file
-    transcription = transcribe_speech(voice)
-    # Now use this transcription with your get_dreamtalk function
-    video = get_dreamtalk(image_in, transcription)
-    return video
+    try:
+        speech = transcribe_speech(voice)
+        video = get_dreamtalk(image_in, speech)
+        return video
+    except Exception as e:
+        print(f"An error occurred while processing: {e}")
+        raise gr.Error(f"An error occurred while processing: {str(e)}")
 
 with gr.Blocks() as demo:
     with gr.Column():
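The added get_dreamtalk hands video synthesis off to the fffiloni/dreamtalk Space via gradio_client, and pipe now wraps both steps in try/except so failures surface in the UI as gr.Error rather than a silent stack trace. The Client import presumably lives in the unshown top of app.py (lines 1-3). A minimal standalone sketch of the same call, assuming the Space is reachable, gradio_client is installed, and using placeholder file paths:

    # Sketch of calling the DreamTalk Space through gradio_client;
    # argument order mirrors the call added in this commit.
    from gradio_client import Client

    client = Client("https://fffiloni-dreamtalk.hf.space/")
    result = client.predict(
        "sample.wav",                         # placeholder audio filepath
        "face.png",                           # placeholder image filepath
        "M030_front_neutral_level1_001.mat",  # emotional-style preset
        api_name="/infer"
    )
    print(result)  # the commit reads result['video'], so the Space
                   # presumably returns a dict containing the video path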
@@ -48,4 +64,4 @@ with gr.Blocks() as demo:
         outputs=[video_o],
         concurrency_limit=3
     )
-demo.queue(max_size=10).launch(show_error=True, show_api=False)
+demo.queue(max_size=10).launch(show_error=True, show_api=False)
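The final hunk removes and re-adds an identical launch line, which usually indicates a whitespace or end-of-file newline change. For context, queue(max_size=10) bounds how many jobs may wait while concurrency_limit=3 caps parallel runs of the event. A minimal Blocks skeleton with the same settings; except for video_o and pipe, the component names here are illustrative, since the file's actual ones are not visible in the diff:

    import gradio as gr

    def pipe(text, voice, image_in):
        # placeholder body; the real app transcribes `voice` and
        # forwards the result to the DreamTalk Space
        return None

    with gr.Blocks() as demo:
        with gr.Column():
            text_in = gr.Textbox(label="Text")    # illustrative names
            voice_in = gr.Audio(type="filepath", label="Voice")
            image_in = gr.Image(type="filepath", label="Image")
            video_o = gr.Video(label="Result")
            submit = gr.Button("Submit")
        submit.click(
            fn=pipe,
            inputs=[text_in, voice_in, image_in],
            outputs=[video_o],
            concurrency_limit=3  # at most 3 concurrent runs of this event
        )

    demo.queue(max_size=10).launch(show_error=True, show_api=False)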