File size: 2,684 Bytes
8362005
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import asyncio
from typing import AsyncGenerator

import numpy as np
import torch
from dia.model import Dia
from loguru import logger
from pydantic import BaseModel, Field

from pipecat.frames.frames import (
    ErrorFrame,
    Frame,
    TTSAudioRawFrame,
    TTSStartedFrame,
    TTSStoppedFrame,
)
from pipecat.services.tts_service import TTSService


class DiaTTSService(TTSService):
    """Text-to-speech service backed by the Dia model.

    The model is loaded once in ``__init__``. Each ``run_tts`` call
    synthesizes the complete utterance in a single, non-streaming generation
    pass that is dispatched to the event loop's default executor so the loop
    itself is not blocked while the model runs.
    """

    class InputParams(BaseModel):
        """Configuration parameters for Dia TTS service."""

        # Forwarded to ``Dia.generate`` as the ``use_torch_compile`` keyword.
        use_torch_compile: bool = Field(False)
        # Forwarded to ``Dia.generate`` as the ``verbose`` keyword.
        verbose: bool = Field(False)

    def __init__(
        self,
        *,
        model_name: str = "nari-labs/Dia-1.6B",
        compute_dtype: str = "float32",
        device: str = "cpu",
        sample_rate: int = 24000,
        params: InputParams = InputParams(),
        **kwargs,
    ):
        """Initialize Dia TTS service.

        Args:
            model_name: Model identifier passed to ``Dia.from_pretrained``.
            compute_dtype: Compute dtype name passed to ``Dia.from_pretrained``.
            device: Device string used to build the ``torch.device`` the model
                is loaded on.
            sample_rate: Sample rate handed to the base ``TTSService``.
            params: Generation settings (see ``InputParams``). The shared
                default instance is only read, never mutated.
            **kwargs: Forwarded unchanged to ``TTSService.__init__``.
        """
        super().__init__(sample_rate=sample_rate, **kwargs)
        logger.info(f"Initializing Dia TTS service with model: {model_name}")

        torch_device = torch.device(device)

        self._model = Dia.from_pretrained(
            model_name, compute_dtype=compute_dtype, device=torch_device
        )
        # Snapshot the settings as a plain dict so later lookups are by key.
        self._settings = params.dict()
        logger.info("Dia TTS service initialized")

    def can_generate_metrics(self) -> bool:
        """Report that this service produces metrics."""
        return True

    @staticmethod
    def _to_pcm16_bytes(output) -> bytes:
        """Convert the model output into 16-bit PCM bytes.

        Accepts a mapping carrying the audio under ``"audio_tensor"``, a
        ``torch.Tensor``, or anything ``numpy.asarray`` understands. Samples
        are treated as floats in ``[-1, 1]``.
        """
        audio = output["audio_tensor"] if isinstance(output, dict) else output
        if isinstance(audio, torch.Tensor):
            # detach(): .numpy() refuses tensors that track gradients.
            # float(): NumPy has no bfloat16, so normalise the dtype first.
            audio = audio.detach().cpu().float().numpy()
        samples = np.asarray(audio, dtype=np.float32)
        # Clip before scaling: values even slightly outside [-1, 1] would
        # otherwise wrap around when cast to int16 and produce loud clicks.
        samples = np.clip(samples, -1.0, 1.0)
        return (samples * 32767).astype(np.int16).tobytes()

    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        """Synthesize ``text`` and yield the resulting frames.

        Yields ``TTSStartedFrame``, one ``TTSAudioRawFrame`` holding the whole
        utterance, then ``TTSStoppedFrame``. If anything raises, the error is
        logged with its traceback and a single ``ErrorFrame`` is yielded
        instead of the remaining frames.

        Args:
            text: The text to synthesize.
        """
        logger.debug(f"Generating TTS for: [{text}]")
        try:
            await self.start_ttfb_metrics()
            yield TTSStartedFrame()

            await self.start_tts_usage_metrics(text)

            settings = self._settings
            loop = asyncio.get_running_loop()

            # run_in_executor only forwards positional arguments, and the
            # parameters following ``text`` in ``generate`` are not these
            # flags, so bind them by keyword inside a closure.
            output = await loop.run_in_executor(
                None,
                lambda: self._model.generate(
                    text,
                    use_torch_compile=settings["use_torch_compile"],
                    verbose=settings["verbose"],
                ),
            )

            audio_data = self._to_pcm16_bytes(output)

            # The whole utterance arrives at once, so the first byte is
            # available exactly when generation has finished.
            await self.stop_ttfb_metrics()

            # NOTE(review): no resampling happens here; confirm the model's
            # native output rate matches ``self.sample_rate``.
            yield TTSAudioRawFrame(
                audio=audio_data,
                sample_rate=self.sample_rate,
                num_channels=1,
            )

            yield TTSStoppedFrame()
        except Exception as e:
            # loguru does not understand ``exc_info``; ``exception`` logs at
            # ERROR level and attaches the active traceback.
            logger.exception(f"{self} exception: {e}")
            yield ErrorFrame(f"Error generating audio: {str(e)}")