Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,970 Bytes
87930ea 3268a02 87930ea 3268a02 87930ea 3268a02 87930ea 3268a02 87930ea 64e2453 87930ea 3268a02 87930ea 3268a02 87930ea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 |
import time
import traceback
from dataclasses import dataclass, field
import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import spaces
import torch
import xxhash
from datasets import Audio
from transformers import AutoModel
import io
from pydub import AudioSegment
import tempfile
from utils.vad import VadOptions, collect_chunks, get_speech_timestamps
# Identifier of the checkpoint to load and the rate the resampler is set to.
_DIVA_CHECKPOINT = "WillHeld/DiVA-llama-3-v0-8b"
_RESAMPLE_RATE_HZ = 16_000

# Loaded once at import time. trust_remote_code=True allows transformers to
# run the modelling code shipped inside the checkpoint repository.
diva_model = AutoModel.from_pretrained(_DIVA_CHECKPOINT, trust_remote_code=True)

# `datasets.Audio` feature configured for 16 kHz; diva_audio encodes and then
# decodes each clip through it before handing the samples to the model.
resampler = Audio(sampling_rate=_RESAMPLE_RATE_HZ)
@spaces.GPU
@torch.no_grad
def diva_audio(audio_input, do_sample=False, temperature=0.001, prev_outs=None):
    """Stream the model's reply to a single audio clip.

    Args:
        audio_input: ``(sampling_rate, samples)`` pair where ``samples`` is a
            NumPy array.
        do_sample: Forwarded unchanged to ``diva_model.generate_stream``.
        temperature: Kept for interface compatibility; this function does not
            forward it to the model.
        prev_outs: Passed to the model as ``init_outputs``.

    Yields:
        Whatever ``diva_model.generate_stream`` yields when called with
        ``return_outputs=True``.
    """
    sr, y = audio_input
    y = y.astype(np.float32)

    # Peak-normalise, but leave an all-zero (silent) clip untouched: dividing
    # by a zero peak would turn every sample into NaN.
    peak = np.max(np.abs(y))
    if peak > 0:
        y /= peak

    # Round-trip through the module-level `resampler` feature to obtain the
    # clip at the rate that feature was configured with.
    a = resampler.decode_example(
        resampler.encode_example({"array": y, "sampling_rate": sr})
    )
    yield from diva_model.generate_stream(
        a["array"],
        None,
        do_sample=do_sample,
        max_new_tokens=256,
        init_outputs=prev_outs,
        return_outputs=True,
    )
def run_vad(ori_audio, sr):
    """Remove non-speech from a clip and report how much speech remains.

    Args:
        ori_audio: NumPy array of samples. Values are scaled by 1/32768 before
            detection, so they are presumably in the int16 PCM range.
        sr: Sampling rate of ``ori_audio`` in Hz.

    Returns:
        A ``(duration_after_vad, vad_audio, elapsed)`` tuple:

        * ``duration_after_vad`` - seconds of audio kept by the detector.
        * ``vad_audio`` - the kept audio as int16 bytes at the original rate.
        * ``elapsed`` - wall-clock seconds spent, rounded to 4 decimals.

        If anything raises, the traceback is printed and
        ``(-1, ori_audio, elapsed)`` is returned instead.
    """
    _st = time.time()
    try:
        audio = ori_audio.astype(np.float32) / 32768.0

        # Detection runs at 16 kHz; resample first if the input differs.
        sampling_rate = 16000
        if sr != sampling_rate:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=sampling_rate)

        speech_chunks = get_speech_timestamps(audio, VadOptions())
        audio = collect_chunks(audio, speech_chunks)
        duration_after_vad = audio.shape[0] / sampling_rate

        if sr != sampling_rate:
            # resample to original sampling rate
            vad_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
        else:
            vad_audio = audio

        vad_audio = np.round(vad_audio * 32768.0).astype(np.int16)
        vad_audio_bytes = vad_audio.tobytes()
        return duration_after_vad, vad_audio_bytes, round(time.time() - _st, 4)
    except Exception:
        # `ori_audio` holds samples, not bytes, so its duration is len / sr.
        # Guard sr == 0 so the log line itself cannot raise.
        audio_len = len(ori_audio) / sr if sr else float("nan")
        msg = f"[asr vad error] audio_len: {audio_len:.3f} s, trace: {traceback.format_exc()}"
        print(msg)
        return -1, ori_audio, round(time.time() - _st, 4)
def warm_up():
    """Run the VAD pipeline once on a dummy buffer and print the time taken.

    Presumably this pays any one-time initialisation cost before real audio
    arrives; the result of the detection itself is discarded.
    """
    dummy = np.ones(2048)  # 2048 samples, all 1.0
    _, _, elapsed = run_vad(dummy, 16000)
    print(f"warm up done, time_cost: {elapsed:.3f} s")


warm_up()
@dataclass
class AppState:
stream: np.ndarray | None = None
sampling_rate: int = 0
pause_detected: bool = False
started_talking: bool = False
stopped: bool = False
conversation: list = field(default_factory=list)
model_outs: any = None
def determine_pause(audio: np.ndarray, sampling_rate: int, state: AppState) -> bool:
    """Decide whether the speaker has paused, given all audio captured so far.

    Args:
        audio: Every sample received in the current turn.
        sampling_rate: Sampling rate of ``audio`` in Hz.
        state: Session state; ``started_talking`` is set to True the first
            time more than 0.5 s of speech is detected.

    Returns:
        True when more than one second of the captured audio is non-speech.
        False on the call that first detects speech, and False when voice
        activity detection failed.
    """
    dur_vad, _, time_vad = run_vad(audio, sampling_rate)

    # run_vad reports failure with a duration of -1 (and has already printed
    # the traceback). Feeding that sentinel into the arithmetic below would
    # report a pause for any non-empty audio, so treat failure as "no pause".
    if dur_vad < 0:
        return False

    duration = len(audio) / sampling_rate

    if dur_vad > 0.5 and not state.started_talking:
        print("started talking")
        state.started_talking = True
        return False

    print(f"duration_after_vad: {dur_vad:.3f} s, time_vad: {time_vad:.3f} s")

    # Total length minus detected speech = amount of silence in the buffer.
    return (duration - dur_vad) > 1
def process_audio(audio: tuple, state: AppState):
    """Append one streamed microphone chunk to the session and check for a pause.

    Args:
        audio: ``(sampling_rate, samples)`` chunk from the audio component.
        state: Session state; its ``stream``, ``sampling_rate`` and
            ``pause_detected`` fields are updated in place.

    Returns:
        ``(audio_update, state)``. ``audio_update`` is
        ``gr.Audio(recording=False)`` once a pause is detected after the user
        has started talking, otherwise ``None``.
    """
    chunk = audio[1]
    if state.stream is not None:
        state.stream = np.concatenate((state.stream, chunk))
    else:
        # First chunk of the turn: it also fixes the sampling rate.
        state.stream = chunk
        state.sampling_rate = audio[0]

    state.pause_detected = determine_pause(state.stream, state.sampling_rate, state)

    turn_finished = state.pause_detected and state.started_talking
    audio_update = gr.Audio(recording=False) if turn_finished else None
    return audio_update, state
def response(state: AppState):
    """Generate the assistant's reply for the audio captured in ``state``.

    The recorded audio is written to a WAV file under ``/tmp`` and appended
    to the conversation as the user message. The reply from ``diva_audio`` is
    then streamed into a single assistant message.

    Args:
        state: Session state holding the recorded audio, the conversation so
            far and the model outputs of the previous turn.

    Yields:
        ``(state, conversation)`` after every partial reply, then a final
        ``(new_state, conversation)`` where ``new_state`` is a fresh
        ``AppState`` carrying over only the conversation and the latest model
        outputs.
    """
    # Nothing to answer: no audio was captured, or neither speech nor a pause
    # was ever detected. This is a generator, so a value passed to `return`
    # would never reach a consumer iterating it; end without yielding and
    # leave `state` as it is.
    if state.stream is None or not (state.pause_detected or state.started_talking):
        return

    file_name = f"/tmp/{xxhash.xxh32(bytes(state.stream)).hexdigest()}.wav"
    sf.write(file_name, state.stream, state.sampling_rate, format="wav")
    state.conversation.append(
        {"role": "user", "content": {"path": file_name, "mime_type": "audio/wav"}}
    )

    # Fall back to the previous outputs so the final yield below is well
    # defined even if the model stream produces nothing.
    outs = state.model_outs
    started = False
    for resp, outs in diva_audio(
        (state.sampling_rate, state.stream), prev_outs=state.model_outs
    ):
        if started:
            # Later chunks overwrite the assistant message added below.
            state.conversation[-1]["content"] = resp
        else:
            state.conversation.append({"role": "assistant", "content": resp})
            started = True
        yield state, state.conversation

    yield AppState(conversation=state.conversation, model_outs=outs), state.conversation
def start_recording_user(state: AppState):
    """Switch the microphone back on unless the conversation was stopped.

    Returns:
        ``gr.Audio(recording=True)`` when ``state.stopped`` is false,
        otherwise ``None``.
    """
    if state.stopped:
        return None
    return gr.Audio(recording=True)
# Primary palette: every shade is the colour #820000 with a different alpha
# value in the last two hex digits.
# NOTE(review): c50 and c500 carry the same value - confirm this is intended.
_primary_palette = gr.themes.Color(
    c50="#8200007f",
    c100="#82000019",
    c200="#82000033",
    c300="#8200004c",
    c400="#82000066",
    c500="#8200007f",
    c600="#82000099",
    c700="#820000b2",
    c800="#820000cc",
    c900="#820000e5",
    c950="#820000f2",
)

theme = gr.themes.Soft(
    primary_hue=_primary_palette,
    secondary_hue="rose",
    neutral_hue="stone",
)
def _stop_conversation():
    """Return a stopped session state and an update that ends recording."""
    return AppState(stopped=True), gr.Audio(recording=False)


with gr.Blocks(theme=theme) as demo:
    # Layout: microphone input on the left, chat history on the right.
    with gr.Row():
        with gr.Column():
            input_audio = gr.Audio(
                label="Input Audio", sources="microphone", type="numpy"
            )
        with gr.Column():
            chatbot = gr.Chatbot(label="Conversation", type="messages")
    state = gr.State(value=AppState())

    # Each streamed chunk goes through process_audio, which may end recording.
    stream = input_audio.stream(
        fn=process_audio,
        inputs=[input_audio, state],
        outputs=[input_audio, state],
        stream_every=0.50,
        time_limit=30,
    )

    # When recording stops, stream the reply, then try to resume recording.
    respond = input_audio.stop_recording(
        fn=response, inputs=[state], outputs=[state, chatbot]
    )
    respond.then(fn=start_recording_user, inputs=[state], outputs=[input_audio])

    # The stop button cancels the in-flight events and marks the session stopped.
    cancel = gr.Button("Stop Conversation", variant="stop")
    cancel.click(
        fn=_stop_conversation,
        inputs=None,
        outputs=[state, input_audio],
        cancels=[respond, stream],
    )


demo.launch(share=True)
|