# app.py — commit 2eaaa08 ("Update app.py"), author: Hobis
import functools
import os
import tempfile

import gradio as gr
import numpy as np
import torch
import torchaudio
from encodec import EncodecModel
from encodec.utils import convert_audio

from hubert.customtokenizer import CustomTokenizer
from hubert.hubert_manager import HuBERTManager
from hubert.pre_kmeans_hubert import CustomHubert
# Checkpoint locations and default output file, unchanged from the original app.
HUBERT_CHECKPOINT = 'data/models/hubert/hubert.pt'
TOKENIZER_CHECKPOINT = 'data/models/hubert/tokenizer.pth'
DEFAULT_OUTPUT_PATH = 'helloWorld.npz'


@functools.lru_cache(maxsize=None)
def _load_hubert(checkpoint_path):
    """Load the HuBERT feature extractor once per checkpoint and reuse it across requests."""
    return CustomHubert(checkpoint_path=checkpoint_path)


@functools.lru_cache(maxsize=None)
def _load_tokenizer(checkpoint_path):
    """Load the semantic-token quantizer once per checkpoint and reuse it across requests."""
    return CustomTokenizer.load_from_checkpoint(checkpoint_path)


@functools.lru_cache(maxsize=None)
def _load_encodec(bandwidth):
    """Build the 24 kHz EnCodec model at *bandwidth* (kbps) once and reuse it."""
    model = EncodecModel.encodec_model_24khz()
    model.set_target_bandwidth(bandwidth)
    return model


def _to_numpy(value):
    """Return *value* as a NumPy array, detaching/moving torch tensors to CPU first."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return np.asarray(value)


def process_audio(audio_file, output_path=DEFAULT_OUTPUT_PATH):
    """Encode an audio file into an ``.npz`` voice prompt.

    The file gets three arrays:

    * ``semantic_prompt``: tokens from the HuBERT features quantized by
      ``CustomTokenizer``;
    * ``fine_prompt``: all EnCodec codebooks, shape ``(n_codebooks, frames)``;
    * ``coarse_prompt``: the first two codebooks of ``fine_prompt``.

    NOTE(review): the key layout looks like a Bark voice-prompt history
    file; confirm against the consumer.

    Args:
        audio_file: Path of any audio file ``torchaudio.load`` can read.
        output_path: Destination passed to ``np.savez``. Defaults to the
            historical ``helloWorld.npz``.

    Returns:
        ``output_path``.
    """
    wav, sr = torchaudio.load(audio_file)
    # Down-mix every multi-channel recording (not only stereo) to mono.
    if wav.shape[0] > 1:
        wav = wav.mean(0, keepdim=True)

    # Inference only: no autograd graph is needed for any of the models.
    with torch.no_grad():
        semantic_vectors = _load_hubert(HUBERT_CHECKPOINT).forward(wav, input_sample_hz=sr)
        semantic_tokens = _load_tokenizer(TOKENIZER_CHECKPOINT).get_token(semantic_vectors)

        model = _load_encodec(6.0)
        wav = convert_audio(wav, sr, model.sample_rate, model.channels)
        encoded_frames = model.encode(wav.unsqueeze(0))  # add batch dimension

    # Concatenate the per-frame codes along time. Drop only the batch axis, so
    # a one-frame clip still keeps a 2-D (codebooks, frames) shape.
    codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze(0)
    fine_prompt = _to_numpy(codes)
    coarse_prompt = fine_prompt[:2, :]
    np.savez(
        output_path,
        semantic_prompt=_to_numpy(semantic_tokens),
        fine_prompt=fine_prompt,
        coarse_prompt=coarse_prompt,
    )
    return output_path
def audio_file_processing(input_audio):
    """Gradio callback: process the audio chosen in the UI into ``helloWorld.npz``.

    Args:
        input_audio: Value of the Gradio Audio input. It is one of:

            * a file path, when the component uses ``type="filepath"``;
            * a ``(sample_rate, samples)`` tuple, when it uses ``type="numpy"``
              (the legacy ``gr.inputs.Audio`` default). ``samples`` is
              ``(frames,)`` or ``(frames, channels)``;
            * ``None``, when nothing was selected.

    Returns:
        Status message for the output textbox.
    """
    if input_audio is None:
        return "Nie wybrano pliku audio"

    if isinstance(input_audio, (str, os.PathLike)):
        process_audio(input_audio)
    else:
        # Numpy-mode input: write the samples to a temporary WAV, because
        # process_audio loads its input from a path.
        sample_rate, samples = input_audio
        data = np.asarray(samples)
        if data.ndim == 1:
            data = data[:, np.newaxis]  # (frames,) -> (frames, 1)
        with tempfile.TemporaryDirectory() as tmp_dir:
            wav_path = os.path.join(tmp_dir, 'input.wav')
            torchaudio.save(
                wav_path,
                torch.from_numpy(np.ascontiguousarray(data)),
                sample_rate,
                channels_first=False,
            )
            process_audio(wav_path)
    return "Plik audio został przetworzony i zapisany jako helloWorld.npz"
# Build the UI with the Gradio >= 3 component API. The legacy gr.inputs /
# gr.outputs namespaces were removed in Gradio 4.
# type="filepath" makes Gradio pass the callback a path on disk, which is
# what process_audio loads.
audio_input = gr.Audio(type="filepath", label="Wybierz plik audio")  # "Choose an audio file"
audio_output = gr.Textbox(label="Status")
demo = gr.Interface(fn=audio_file_processing, inputs=audio_input, outputs=audio_output)

# Launch only when run as a script, so importing this module has no side effects.
if __name__ == "__main__":
    demo.launch()