import gradio as gr
import numpy as np
import torch
import wave

from scipy.io.wavfile import write as write_wav
from transformers import AutoProcessor, BarkModel
from optimum.bettertransformer import BetterTransformer

device = "cuda:0" if torch.cuda.is_available() else "cpu"

model = BarkModel.from_pretrained("suno/bark", torch_dtype=torch.float16)
model = model.to(device)

processor = AutoProcessor.from_pretrained("suno/bark")

# Use BetterTransformer for flash attention
model = BetterTransformer.transform(model, keep_original_model=False)

# Enable CPU offload to reduce GPU memory usage between sub-model calls
model.enable_cpu_offload()
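# Note on dependencies: BetterTransformer comes from the `optimum` package
# (imported above), and BarkModel.enable_cpu_offload() relies on `accelerate`;
# both must be installed alongside transformers, torch, scipy, and gradio.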
def split_text_into_sentences(text):
    """Split text into sentences, treating a trailing '.' as the delimiter."""
    sentences = []
    current_sentence = ''
    for word in text.split():
        current_sentence += ' ' + word
        if word.endswith('.'):
            sentences.append(current_sentence.strip())
            current_sentence = ''
    # Keep any trailing text that does not end with a period
    if current_sentence:
        sentences.append(current_sentence.strip())
    return sentences
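# Illustrative example (not part of the app flow):
#   split_text_into_sentences("It was a dark night. The wind howled.")
#   -> ['It was a dark night.', 'The wind howled.']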
def join_wav_files(input_files, output_file):
    """Concatenate several WAV files into one, copying parameters from the first."""
    # Read the audio parameters (channels, sample width, frame rate) from the first file
    with wave.open(input_files[0], 'rb') as first_file:
        params = first_file.getparams()
    # Write every input file's frames into a single output file
    with wave.open(output_file, 'wb') as output:
        output.setparams(params)
        for input_file in input_files:
            with wave.open(input_file, 'rb') as input_wav:
                output.writeframes(input_wav.readframes(input_wav.getnframes()))
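# Illustrative usage (assumes the listed files exist and share the same channel
# count, sample width, and frame rate, since parameters are copied from the
# first input file):
#   join_wav_files(["output_0.wav", "output_1.wav"], "full_story.wav")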
def infer(text_prompt):
    print("Cutting text into chunks")
    text_chunks = split_text_into_sentences(text_prompt)
    # Generate one WAV file per sentence, then stitch them together
    wav_files = generate(text_chunks, "wav")
    print(wav_files)
    output_wav = 'full_story.wav'
    join_wav_files(wav_files, output_wav)
    return output_wav
def generate(text_prompt, out_type):
    # text_prompt may be a list of sentences; the processor batches them
    inputs = processor(text_prompt, voice_preset="v2/en_speaker_6").to(device)
    with torch.inference_mode():
        speech_output = model.generate(**inputs)

    sampling_rate = model.generation_config.sample_rate
    input_waves = []
    for i, speech_out in enumerate(speech_output):
        audio_array = speech_out.cpu().numpy().squeeze()
        if out_type == "numpy":
            # Gradio's Audio component accepts a (sampling_rate, array) tuple
            input_waves.append((sampling_rate, audio_array))
        elif out_type == "wav":
            # Scale the float waveform (expected in [-1, 1]) to 16-bit signed integers
            audio_data = np.int16(audio_array * 32767)
            write_wav(f"output_{i}.wav", sampling_rate, audio_data)
            input_waves.append(f"output_{i}.wav")
    return input_waves
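# For suno/bark, model.generation_config.sample_rate is typically 24 kHz, so
# every chunk written by generate() shares the same WAV parameters and can be
# concatenated safely by join_wav_files().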
with gr.Blocks() as demo:
    with gr.Column():
        prompt = gr.Textbox(label="prompt")
        submit_btn = gr.Button("Submit")
        audio_out = gr.Audio()
    submit_btn.click(fn=infer, inputs=[prompt], outputs=[audio_out])

demo.launch()