Lwasinam committed on
Commit
c16872d
1 Parent(s): dc0a0a5

Create app.py

Files changed (1)
  1. app.py +120 -0
app.py ADDED
@@ -0,0 +1,120 @@
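+ # Voicera TTS demo: a Gradio app that generates speech from text.
+ # The language model emits SNAC audio-codec tokens, which are regrouped
+ # into per-codebook tensors and decoded back into a waveform.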
+ import gradio as gr
+ import torch
+ import soundfile as sf
+ from snac import SNAC
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
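+ # The model emits a flat stream of codec tokens in which the id 50258 acts
+ # as a frame separator; the helpers below split that stream back into
+ # per-codebook sequences before decoding.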
+ def find_last_instance_of_separator(lst, element=50258):
+     # Search from the end so that a trailing, incomplete frame is dropped.
+     reversed_list = lst[::-1]
+     try:
+         reversed_index = reversed_list.index(element)
+         return len(lst) - 1 - reversed_index
+     except ValueError:
+         raise ValueError(f"Separator token {element} not found in the generated sequence")
+
+ def reconstruct_tensors(flattened_output):
+     def count_elements_between_hashes(lst):
+         # Frame size = number of codes between the first two separators.
+         try:
+             first_index = lst.index(50258)
+             second_index = lst.index(50258, first_index + 1)
+             return second_index - first_index - 1
+         except ValueError:
+             raise ValueError("Output does not contain two separator tokens (50258)")
+
+     def remove_elements_before_hash(flattened_list):
+         # Drop any leading tokens emitted before the first separator.
+         try:
+             first_hash_index = flattened_list.index(50258)
+             return flattened_list[first_hash_index:]
+         except ValueError:
+             raise ValueError("Output does not contain the separator token (50258)")
+
+     def list_to_torch_tensor(tensor1):
+         tensor = torch.tensor(tensor1)
+         tensor = tensor.unsqueeze(0)
+         return tensor
+
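+     # SNAC encodes audio with a hierarchy of codebooks running at different
+     # temporal rates. A 7-token frame corresponds to the 3-codebook 24 kHz
+     # model (1 coarse + 2 mid + 4 fine codes per frame); a 15-token frame
+     # corresponds to the 4-codebook variants (1 + 2 + 4 + 8). The demux
+     # below undoes the interleaved flattening, one tensor per codebook.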
+     flattened_output = remove_elements_before_hash(flattened_output)
+     last_index = find_last_instance_of_separator(flattened_output)
+     flattened_output = flattened_output[:last_index]
+
+     codes = []
+     tensor1 = []
+     tensor2 = []
+     tensor3 = []
+     tensor4 = []
+
+     n_tensors = count_elements_between_hashes(flattened_output)
+     if n_tensors == 7:
+         for i in range(0, len(flattened_output), 8):
+             tensor1.append(flattened_output[i + 1])
+             tensor2.append(flattened_output[i + 2])
+             tensor3.append(flattened_output[i + 3])
+             tensor3.append(flattened_output[i + 4])
+             tensor2.append(flattened_output[i + 5])
+             tensor3.append(flattened_output[i + 6])
+             tensor3.append(flattened_output[i + 7])
+         codes = [
+             list_to_torch_tensor(tensor1).to(device),
+             list_to_torch_tensor(tensor2).to(device),
+             list_to_torch_tensor(tensor3).to(device),
+         ]
+     elif n_tensors == 15:
+         for i in range(0, len(flattened_output), 16):
+             tensor1.append(flattened_output[i + 1])
+             tensor2.append(flattened_output[i + 2])
+             tensor3.append(flattened_output[i + 3])
+             tensor4.append(flattened_output[i + 4])
+             tensor4.append(flattened_output[i + 5])
+             tensor3.append(flattened_output[i + 6])
+             tensor4.append(flattened_output[i + 7])
+             tensor4.append(flattened_output[i + 8])
+             tensor2.append(flattened_output[i + 9])
+             tensor3.append(flattened_output[i + 10])
+             tensor4.append(flattened_output[i + 11])
+             tensor4.append(flattened_output[i + 12])
+             tensor3.append(flattened_output[i + 13])
+             tensor4.append(flattened_output[i + 14])
+             tensor4.append(flattened_output[i + 15])
+         codes = [
+             list_to_torch_tensor(tensor1).to(device),
+             list_to_torch_tensor(tensor2).to(device),
+             list_to_torch_tensor(tensor3).to(device),
+             list_to_torch_tensor(tensor4).to(device),
+         ]
+     else:
+         raise ValueError(f"Unexpected frame size: {n_tensors} codes between separators")
+
+     return codes
+
+ def load_model():
+     # Fine-tuned Voicera checkpoint plus the SNAC codec used to decode audio tokens.
+     tokenizer = AutoTokenizer.from_pretrained("Lwasinam/voicera-jenny-finetune")
+     model = AutoModelForCausalLM.from_pretrained("Lwasinam/voicera-jenny-finetune").to(device)
+     snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
+     return model, tokenizer, snac_model
+
+ def SpeechDecoder(codes, snac_model):
+     # Map the flat token sequence back to per-codebook tensors and decode to audio.
+     codes = codes.squeeze(0).tolist()
+     reconstructed_codes = reconstruct_tensors(codes)
+     audio_hat = snac_model.to(device).decode(reconstructed_codes)
+     audio_path = "reconstructed_audio.wav"
+     # The 24 kHz SNAC model produces audio at a 24000 Hz sample rate.
+     sf.write(audio_path, audio_hat.squeeze().cpu().detach().numpy(), 24000)
+     return audio_path
+
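+ # Generation combines beam search (num_beams=5) with nucleus sampling
+ # (do_sample=True, top_p=0.95); max_length=1024 tokens caps the length
+ # of the generated clip.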
+ def generate_audio(text, tokenizer, model, snac_model):
+     with torch.no_grad():
+         inputs = tokenizer(text, return_tensors="pt").to(device)
+         output_codes = model.generate(
+             inputs["input_ids"],
+             attention_mask=inputs["attention_mask"],
+             max_length=1024,
+             num_beams=5,
+             top_p=0.95,
+             temperature=0.8,
+             do_sample=True,
+             repetition_penalty=2.0,
+         )
+     audio_path = SpeechDecoder(output_codes, snac_model)
+     return audio_path
+
+ def main(text):
+     # Note: the models are reloaded on every request; caching them at module
+     # level would speed up repeated generations.
+     model, tokenizer, snac_model = load_model()
+     audio_path = generate_audio(text, tokenizer, model, snac_model)
+     return audio_path
+
+ # Define the Gradio interface
+ iface = gr.Interface(
+     fn=main,
+     inputs=gr.Textbox(label="Enter text:", lines=2, placeholder="Type your text here..."),
+     outputs="audio",
+     title="Voicera TTS",
+     description="Generate speech from text using the Voicera TTS model.",
+ )
+
+ if __name__ == "__main__":
+     iface.launch()