File size: 1,530 Bytes
9470ace
 
 
 
 
c7c7522
9470ace
 
 
c7c7522
9470ace
c7c7522
 
 
 
 
 
 
 
9470ace
 
c7c7522
 
 
9470ace
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# tts_utils.py
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer

# Updated load_model function in tts_utils.py
def load_model(model_name="ai4bharat/indic-parler-tts", *, quantize=True):
    """Load a Parler-TTS checkpoint together with its two tokenizers.

    Args:
        model_name: Hugging Face hub id (or local path) of the checkpoint.
            Defaults to ``"ai4bharat/indic-parler-tts"`` (the previous
            hard-coded value).
        quantize: When true (the default, matching the previous behavior),
            apply PyTorch dynamic int8 quantization to every ``nn.Linear``
            layer. Dynamically quantized layers only execute on CPU.

    Returns:
        ``(model, tokenizer, description_tokenizer)``:
        ``tokenizer`` encodes the text to be spoken; ``description_tokenizer``
        encodes the voice-description prompt for the model's text encoder.
    """
    model = ParlerTTSForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype=torch.float32,  # Force CPU-compatible dtype
    )

    # The voice description is consumed by the model's separate text encoder,
    # so it must be tokenized with *that* encoder's tokenizer, not with the
    # prompt tokenizer shipped in the checkpoint repo.
    text_encoder_name = model.config.text_encoder._name_or_path

    if quantize:
        # Dynamic quantization converts Linear weights to int8; activations are
        # quantized on the fly at inference time (CPU backends only).
        model = torch.ao.quantization.quantize_dynamic(
            model,
            {torch.nn.Linear},
            dtype=torch.qint8,
        )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    description_tokenizer = AutoTokenizer.from_pretrained(text_encoder_name)

    return model, tokenizer, description_tokenizer


def generate_speech(text, voice_prompt, model, tokenizer, description_tokenizer):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    
    description_input_ids = description_tokenizer(
        voice_prompt, 
        return_tensors="pt"
    ).to(device)
    
    prompt_input_ids = tokenizer(text, return_tensors="pt").to(device)
    
    generation = model.generate(
        input_ids=description_input_ids.input_ids,
        attention_mask=description_input_ids.attention_mask,
        prompt_input_ids=prompt_input_ids.input_ids,
        prompt_attention_mask=prompt_input_ids.attention_mask,
        max_new_tokens=1024
    )
    
    return generation.cpu().numpy().squeeze()