Kevin676 committed on
Commit
a69ae8e
1 Parent(s): 002b9bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -7
app.py CHANGED
@@ -17,6 +17,24 @@ model = RWKV(model=model_path, strategy='cuda fp16i8 *8 -> cuda fp16')
17
  from rwkv.utils import PIPELINE, PIPELINE_ARGS
18
  pipeline = PIPELINE(model, "20B_tokenizer.json")
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def generate_prompt(instruction, input=None):
21
  if input:
22
  return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
@@ -39,7 +57,9 @@ def generate_prompt(instruction, input=None):
39
  """
40
 
41
  def evaluate(
42
- instruction,
 
 
43
  # input=None,
44
  # token_count=200,
45
  # temperature=1.0,
@@ -47,13 +67,30 @@ def evaluate(
47
  # presencePenalty = 0.1,
48
  # countPenalty = 0.1,
49
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  args = PIPELINE_ARGS(temperature = max(0.2, float(1)), top_p = float(0.5),
51
  alpha_frequency = 0.4,
52
  alpha_presence = 0.4,
53
  token_ban = [], # ban the generation of some tokens
54
  token_stop = [0]) # stop generation whenever you see any token here
55
 
56
- instruction = instruction.strip()
57
  input=None
58
  # input = input.strip()
59
  ctx = generate_prompt(instruction, input)
@@ -87,12 +124,33 @@ def evaluate(
87
  out_last = i + 1
88
  gc.collect()
89
  torch.cuda.empty_cache()
90
- yield out_str.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  g = gr.Interface(
93
  fn=evaluate,
94
  inputs=[
95
- gr.components.Textbox(lines=2, label="Instruction", value="Tell me about ravens."),
 
 
96
  # gr.components.Textbox(lines=2, label="Input", placeholder="none"),
97
  # gr.components.Slider(minimum=10, maximum=200, step=10, value=150), # token_count
98
  # gr.components.Slider(minimum=0.2, maximum=2.0, step=0.1, value=1.0), # temperature
@@ -101,9 +159,9 @@ g = gr.Interface(
101
  # gr.components.Slider(0.0, 1.0, step=0.1, value=0.4), # countPenalty
102
  ],
103
  outputs=[
104
- gr.inputs.Textbox(
105
- lines=5,
106
- label="Output",
107
  )
108
  ],
109
  title="🥳💬💕 - TalktoAI,随时随地,谈天说地!",
 
17
  from rwkv.utils import PIPELINE, PIPELINE_ARGS
18
  pipeline = PIPELINE(model, "20B_tokenizer.json")
19
 
20
+ from TTS.api import TTS
21
+ tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True)
22
+ import whisper
23
+ model = whisper.load_model("small")
24
+
25
+ os.system('pip install voicefixer --upgrade')
26
+ from voicefixer import VoiceFixer
27
+ voicefixer = VoiceFixer()
28
+
29
+ import torchaudio
30
+ from speechbrain.pretrained import SpectralMaskEnhancement
31
+
32
+ enhance_model = SpectralMaskEnhancement.from_hparams(
33
+ source="speechbrain/metricgan-plus-voicebank",
34
+ savedir="pretrained_models/metricgan-plus-voicebank",
35
+ run_opts={"device":"cuda"},
36
+ )
37
+
38
  def generate_prompt(instruction, input=None):
39
  if input:
40
  return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
 
57
  """
58
 
59
  def evaluate(
60
+ upload,
61
+ audio,
62
+ # instruction,
63
  # input=None,
64
  # token_count=200,
65
  # temperature=1.0,
 
67
  # presencePenalty = 0.1,
68
  # countPenalty = 0.1,
69
  ):
70
+
71
+ audio = whisper.load_audio(audio)
72
+ audio = whisper.pad_or_trim(audio)
73
+
74
+ # make log-Mel spectrogram and move to the same device as the model
75
+ mel = whisper.log_mel_spectrogram(audio).to(model.device)
76
+
77
+ # detect the spoken language
78
+ _, probs = model.detect_language(mel)
79
+ print(f"Detected language: {max(probs, key=probs.get)}")
80
+
81
+ # decode the audio
82
+ options = whisper.DecodingOptions()
83
+ result = whisper.decode(model, mel, options)
84
+
85
+
86
+ res = []
87
  args = PIPELINE_ARGS(temperature = max(0.2, float(1)), top_p = float(0.5),
88
  alpha_frequency = 0.4,
89
  alpha_presence = 0.4,
90
  token_ban = [], # ban the generation of some tokens
91
  token_stop = [0]) # stop generation whenever you see any token here
92
 
93
+ instruction = result.text.strip()
94
  input=None
95
  # input = input.strip()
96
  ctx = generate_prompt(instruction, input)
 
124
  out_last = i + 1
125
  gc.collect()
126
  torch.cuda.empty_cache()
127
+
128
+ res.append(out_str.strip())
129
+
130
+ tts.tts_to_file(res, speaker_wav = upload, language="en", file_path="output.wav")
131
+
132
+ voicefixer.restore(input="output.wav", # input wav file path
133
+ output="audio1.wav", # output wav file path
134
+ cuda=True, # whether to use gpu acceleration
135
+ mode = 0) # You can try out mode 0, 1, or 2 to find out the best result
136
+
137
+ noisy = enhance_model.load_audio(
138
+ "audio1.wav"
139
+ ).unsqueeze(0)
140
+
141
+ enhanced = enhance_model.enhance_batch(noisy, lengths=torch.tensor([1.]))
142
+ torchaudio.save("enhanced.wav", enhanced.cpu(), 16000)
143
+
144
+ return [result.text, res, "enhanced.wav"]
145
+
146
+ # yield out_str.strip()
147
 
148
  g = gr.Interface(
149
  fn=evaluate,
150
  inputs=[
151
+ gr.Audio(source="upload", label = "请上传您喜欢的声音(wav文件)", type="filepath"),
152
+ gr.Audio(source="microphone", label = "和您的专属AI聊天吧!", type="filepath"),
153
+ # gr.components.Textbox(lines=2, label="Instruction", value="Tell me about ravens."),
154
  # gr.components.Textbox(lines=2, label="Input", placeholder="none"),
155
  # gr.components.Slider(minimum=10, maximum=200, step=10, value=150), # token_count
156
  # gr.components.Slider(minimum=0.2, maximum=2.0, step=0.1, value=1.0), # temperature
 
159
  # gr.components.Slider(0.0, 1.0, step=0.1, value=0.4), # countPenalty
160
  ],
161
  outputs=[
162
+ gr.Textbox(label="Speech to Text"),
163
+ gr.Textbox(label="Raven Output"),
164
+ gr.Audio(label="Audio with Custom Voice"),
165
  )
166
  ],
167
  title="🥳💬💕 - TalktoAI,随时随地,谈天说地!",