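"""Hugging Face Inference Endpoints handler for Orpheus TTS with LiveKit streaming.

Loads an Orpheus checkpoint, generates audio tokens from a text prompt, converts
them to a waveform (placeholder decoding below), saves a 24 kHz WAV file, and can
optionally stream the audio into a LiveKit room in real time.
"""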
import torch
import torchaudio
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from livekit import rtc
import asyncio

class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the Orpheus TTS model and tokenizer from the given path (Hub repository).
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        path = "atharva27/orpheus"
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data: dict) -> list:
        # Extract input text and optional voice and LiveKit parameters.
        text_input = data.get("inputs") or data.get("text") or ""
        if not isinstance(text_input, str) or text_input.strip() == "":
            raise ValueError("No text input provided for TTS")
        voice = data.get("voice", "tara")  # default voice (e.g., "tara")
        
        # Format prompt with voice name (Orpheus expects prompts like "voice: text").
        prompt = f"{voice}: {text_input}"
        
        # Encode prompt and generate output tokens with the TTS model.
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
        generate_kwargs = {
            "max_new_tokens": 1024,               # allow sufficient tokens for audio output
            "do_sample": True,
            "temperature": 0.8,
            "top_p": 0.95,
            "repetition_penalty": 1.1,           # >=1.1 for stable speech generation
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        output_ids = self.model.generate(input_ids, **generate_kwargs)
        # The generated sequence includes the prompt; isolate newly generated tokens:
        generated_tokens = output_ids[0, input_ids.size(1):]
        output_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=False)
        
        # Extract audio token IDs from the generated text.
        # Placeholder parsing: keep only integer-looking tokens so non-numeric text
        # does not crash the handler; replace with the model's real audio-token extraction.
        audio_token_ids = [int(tok) for tok in output_text.split() if tok.lstrip("-").isdigit()]

        # Example: convert the audio token IDs to waveform data
        waveform = self.generate_waveform_from_tokens(audio_token_ids)

        # Save or stream waveform
        torchaudio.save("output_audio.wav", waveform, 24000)  # Save as a 24 kHz audio file

        # Stream into a LiveKit room only when connection details are provided in the payload.
        lk_url = data.get("livekit_url")
        lk_token = data.get("livekit_token")
        room_name = data.get("livekit_room", "default-room")
        if lk_url and lk_token:
            asyncio.run(self.stream_audio(lk_url, lk_token, room_name, waveform))

        return [{"status": "success"}]

    def generate_waveform_from_tokens(self, audio_token_ids):
        """
        Convert audio tokens into a waveform (demonstration only).
        Replace this with a real token-to-audio decoder; see the hedged
        _decode_with_snac sketch below for one possible direction.
        """
        # Here we're simulating the waveform by generating random data based on the tokens
        # Replace this logic with actual audio generation
        num_samples = len(audio_token_ids) * 100  # Estimate number of samples based on tokens
        waveform = torch.randn(1, num_samples)  # Simulate random audio waveform
        return waveform
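
    def _decode_with_snac(self, codes):
        """Hedged sketch of a real decoder path (not wired into __call__ above).

        Orpheus-style models emit tokens for a neural audio codec. This sketch
        assumes the optional `snac` package and its `hubertsiuzdak/snac_24khz`
        checkpoint, and assumes `codes` has already been regrouped into the three
        SNAC codebook levels (LongTensors of shape (1, T), (1, 2*T), (1, 4*T)).
        Mapping Orpheus's raw generated token IDs into those levels is
        model-specific and intentionally left out here.
        """
        from snac import SNAC  # assumed extra dependency; imported lazily on purpose

        snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(self.device)
        with torch.inference_mode():
            # decode() returns a float waveform tensor of shape (1, 1, num_samples) at 24 kHz.
            audio = snac_model.decode([c.to(self.device) for c in codes])
        return audio.squeeze(0).cpu()  # (1, num_samples), usable for saving or streaming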

    async def stream_audio(self, lk_url, lk_token, room_name, waveform):
        room = rtc.Room()
        try:
            await room.connect(lk_url, lk_token, options=rtc.RoomOptions(auto_subscribe=True))
        except Exception as e:
            # Surface connection failures instead of silently returning a string.
            raise RuntimeError(f"Failed to connect to LiveKit: {e}") from e

        # Create an audio track for streaming the TTS output
        source = rtc.AudioSource(sample_rate=24000, num_channels=1)
        track = rtc.LocalAudioTrack.create_audio_track("tts-audio", source)
        publish_options = rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE)
        await room.local_participant.publish_track(track, publish_options)

        # Stream the waveform in 50 ms chunks for real-time playback.
        sample_rate = 24000
        frame_duration = 0.05  # 50 ms per frame
        frame_samples = int(sample_rate * frame_duration)  # samples per frame at 24 kHz
        # Convert the float waveform (expected range [-1, 1]) to 16-bit PCM once up front.
        pcm = (waveform.clamp(-1.0, 1.0) * 32767).to(torch.int16).cpu().numpy()
        total_samples = pcm.shape[1]
        for start in range(0, total_samples, frame_samples):
            end = min(start + frame_samples, total_samples)
            chunk = pcm[:, start:end]

            # Create an AudioFrame sized to this chunk and copy the samples into it.
            audio_frame = rtc.AudioFrame.create(sample_rate, 1, chunk.shape[1])
            np.copyto(np.frombuffer(audio_frame.data, dtype=np.int16), chunk.ravel())
            await source.capture_frame(audio_frame)

            # Sleep to keep roughly real-time pacing between frames.
            await asyncio.sleep(frame_duration)

        # Disconnect from the room after streaming is finished
        await room.disconnect()
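

# Minimal local smoke-test sketch (an assumption: running this file directly, e.g.
# `python handler.py`, outside Inference Endpoints). It downloads the model on first
# run, skips LiveKit streaming (no credentials in the payload), and relies on the
# placeholder token-to-waveform logic above.
if __name__ == "__main__":
    handler = EndpointHandler()
    result = handler({"inputs": "Hello from Orpheus!", "voice": "tara"})
    print(result)  # expected: [{"status": "success"}], plus output_audio.wav on disk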