sachin6624 commited on
Commit
b3e0ef2
·
verified ·
1 Parent(s): aa13a8d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +148 -0
app.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from unsloth import FastLanguageModel
4
+ from snac import SNAC
5
+ import numpy as np
6
+
7
+ # Set device globally for the app
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+
10
+ # Load models (globally, once when app starts)
11
+ model_name = "sachin6624/orpheus-3b-0.1-ft-malayalam-3epoch"
12
+ print(f"Loading LLM {model_name} on {device}...")
13
+ model, tokenizer = FastLanguageModel.from_pretrained(
14
+ model_name=model_name,
15
+ max_seq_length=2048,
16
+ dtype=None,
17
+ load_in_4bit=False, # Use True for 4-bit loading to reduce memory if needed
18
+ )
19
+ model.to(device)
20
+ FastLanguageModel.for_inference(model)
21
+ print("LLM loaded.")
22
+
23
+ print(f"Loading SNAC decoder on {device}...")
24
+ snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
25
+ snac_model = snac_model.to(device)
26
+ # Explicitly define sample rate as the model name 'snac_24khz' suggests 24000 Hz
27
+ snac_model_sample_rate = 24000
28
+ print("SNAC decoder loaded. Assumed sample rate:", snac_model_sample_rate)
29
+
30
+ # Define tokens on the selected device
31
+ start_token = torch.tensor([[128259]], dtype=torch.int64, device=device) # Start of human
32
+ end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64, device=device) # End of text, End of human
33
+ token_to_find = 128257
34
+ token_to_remove = 128258
35
+
36
+ def redistribute_codes(code_list):
37
+ """
38
+ Redistributes SNAC codes into layers and decodes them to audio.
39
+ `code_list` is expected to be a list of Python integers.
40
+ """
41
+ if not code_list:
42
+ raise ValueError("Input code_list to redistribute_codes is empty.")
43
+
44
+ layer_1 = []
45
+ layer_2 = []
46
+ layer_3 = []
47
+
48
+ # Ensure there are enough codes to form full groups of 7
49
+ processed_len = (len(code_list) // 7) * 7
50
+ if processed_len == 0:
51
+ raise ValueError("code_list is too short to form any valid SNAC layers.")
52
+
53
+ for i in range(processed_len // 7):
54
+ base_idx = 7*i
55
+ layer_1.append(code_list[base_idx])
56
+ layer_2.append(code_list[base_idx+1]-4096)
57
+ layer_3.append(code_list[base_idx+2]-(2*4096))
58
+ layer_3.append(code_list[base_idx+3]-(3*4096))
59
+ layer_2.append(code_list[base_idx+4]-(4*4096))
60
+ layer_3.append(code_list[base_idx+5]-(5*4096))
61
+ layer_3.append(code_list[base_idx+6]-(6*4096))
62
+
63
+ # Convert lists of Python integers to torch tensors on the specified device
64
+ codes = [
65
+ torch.tensor(layer_1, dtype=torch.long, device=device).unsqueeze(0),
66
+ torch.tensor(layer_2, dtype=torch.long, device=device).unsqueeze(0),
67
+ torch.tensor(layer_3, dtype=torch.long, device=device).unsqueeze(0)
68
+ ]
69
+
70
+ audio_hat = snac_model.decode(codes)
71
+ return audio_hat
72
+
73
+ def generate_audio(prompt: str):
74
+ """
75
+ Generates audio from a given text prompt.
76
+ """
77
+ if not prompt or not prompt.strip():
78
+ raise gr.Error("Please enter a valid text prompt.")
79
+
80
+ try:
81
+ # Tokenize the prompt and prepare input_ids
82
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
83
+ # Concatenate start/end tokens to the input_ids
84
+ modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
85
+
86
+ # Create an attention mask for the unpadded input
87
+ attention_mask = torch.ones_like(modified_input_ids, dtype=torch.long, device=device)
88
+
89
+ # Generate IDs using the model
90
+ generated_ids = model.generate(
91
+ input_ids=modified_input_ids,
92
+ attention_mask=attention_mask,
93
+ max_new_tokens=1200,
94
+ do_sample=True,
95
+ temperature=0.6,
96
+ top_p=0.95,
97
+ repetition_penalty=1.1,
98
+ num_return_sequences=1,
99
+ eos_token_id=128258,
100
+ use_cache = True
101
+ )
102
+
103
+ # Post-process generated_ids to extract SNAC codes
104
+ token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
105
+
106
+ cropped_tensor = generated_ids
107
+ if len(token_indices[1]) > 0:
108
+ last_occurrence_idx = token_indices[1][-1].item()
109
+ cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
110
+
111
+ # Filter out token_to_remove (EOS token for generation)
112
+ processed_row_tensor = cropped_tensor[cropped_tensor != token_to_remove]
113
+
114
+ row_length = processed_row_tensor.size(0)
115
+ new_length = (row_length // 7) * 7 # Ensure length is a multiple of 7 for redistribution
116
+
117
+ if new_length == 0:
118
+ raise gr.Error("Generated response was too short to form valid audio codes. Try a different prompt or longer text.")
119
+
120
+ trimmed_row = processed_row_tensor[:new_length]
121
+ # Convert tensor elements to Python integers and apply offset
122
+ trimmed_row_list = [t.item() - 128266 for t in trimmed_row]
123
+
124
+ samples = redistribute_codes(trimmed_row_list)
125
+ audio_output = samples.detach().squeeze().to("cpu").numpy()
126
+
127
+ return (snac_model_sample_rate, audio_output)
128
+
129
+ except Exception as e:
130
+ raise gr.Error(f"An error occurred during audio generation: {e}")
131
+
132
+ # Gradio Interface setup
133
+ iface = gr.Interface(
134
+ fn=generate_audio,
135
+ inputs=gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Text Prompt (Malayalam)"),
136
+ outputs=gr.Audio(label="Generated Audio", autoplay=True),
137
+ title="Malayalam Text-to-Speech (Orpheus-3B & SNAC)",
138
+ description="Generate speech from Malayalam text using the fine-tuned Orpheus-3B model and SNAC for audio generation.",
139
+ examples=[["എങ്ങനെയുണ്ട് എന്റെ കുട്ടി?, <giggles>."],
140
+ ["നമസ്കാരം, നിങ്ങൾക്ക് സുഖമാണോ?"]],
141
+ )
142
+
143
+ # Use flagging_mode instead of allow_flagging for Gradio 4.0+
144
+ iface.flagging_mode = 'never'
145
+
146
+ # Launch the Gradio app if the script is run directly
147
+ if __name__ == "__main__":
148
+ iface.launch()