diff --git a/pipecat/__init__.py b/pipecat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pipecat/frames/__init__.py b/pipecat/frames/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pipecat/frames/frames.proto b/pipecat/frames/frames.proto new file mode 100644 index 0000000000000000000000000000000000000000..70f3340227704c4ad1174713758802137e9e8086 --- /dev/null +++ b/pipecat/frames/frames.proto @@ -0,0 +1,43 @@ +// +// Copyright (c) 2024, Daily +// +// SPDX-License-Identifier: BSD 2-Clause License +// + +// Generate frames_pb2.py with: +// +// python -m grpc_tools.protoc --proto_path=./ --python_out=./protobufs frames.proto + +syntax = "proto3"; + +package pipecat; + +message TextFrame { + uint64 id = 1; + string name = 2; + string text = 3; +} + +message AudioRawFrame { + uint64 id = 1; + string name = 2; + bytes audio = 3; + uint32 sample_rate = 4; + uint32 num_channels = 5; +} + +message TranscriptionFrame { + uint64 id = 1; + string name = 2; + string text = 3; + string user_id = 4; + string timestamp = 5; +} + +message Frame { + oneof frame { + TextFrame text = 1; + AudioRawFrame audio = 2; + TranscriptionFrame transcription = 3; + } +} diff --git a/pipecat/frames/frames.py b/pipecat/frames/frames.py new file mode 100644 index 0000000000000000000000000000000000000000..86d47625f8c981077c627bd7a98f334aa4420f36 --- /dev/null +++ b/pipecat/frames/frames.py @@ -0,0 +1,340 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from typing import Any, List, Mapping, Tuple + +from dataclasses import dataclass, field + +from pipecat.utils.utils import obj_count, obj_id + + +@dataclass +class Frame: + id: int = field(init=False) + name: str = field(init=False) + + def __post_init__(self): + self.id: int = obj_id() + self.name: str = f"{self.__class__.__name__}#{obj_count(self)}" + + def __str__(self): + return self.name + + +@dataclass +class DataFrame(Frame): + pass + + +@dataclass +class AudioRawFrame(DataFrame): + """A chunk of audio. Will be played by the transport if the transport's + microphone has been enabled. + + """ + audio: bytes + sample_rate: int + num_channels: int + + def __post_init__(self): + super().__post_init__() + self.num_frames = int(len(self.audio) / (self.num_channels * 2)) + + def __str__(self): + return f"{self.name}(size: {len(self.audio)}, frames: {self.num_frames}, sample_rate: {self.sample_rate}, channels: {self.num_channels})" + + +@dataclass +class ImageRawFrame(DataFrame): + """An image. Will be shown by the transport if the transport's camera is + enabled. + + """ + image: bytes + size: Tuple[int, int] + format: str | None + + def __str__(self): + return f"{self.name}(size: {self.size}, format: {self.format})" + + +@dataclass +class URLImageRawFrame(ImageRawFrame): + """An image with an associated URL. Will be shown by the transport if the + transport's camera is enabled. + + """ + url: str | None + + def __str__(self): + return f"{self.name}(url: {self.url}, size: {self.size}, format: {self.format})" + + +@dataclass +class VisionImageRawFrame(ImageRawFrame): + """An image with an associated text to ask for a description of it. Will be + shown by the transport if the transport's camera is enabled. 
+ + """ + text: str | None + + def __str__(self): + return f"{self.name}(text: {self.text}, size: {self.size}, format: {self.format})" + + +@dataclass +class UserImageRawFrame(ImageRawFrame): + """An image associated to a user. Will be shown by the transport if the + transport's camera is enabled. + + """ + user_id: str + + def __str__(self): + return f"{self.name}(user: {self.user_id}, size: {self.size}, format: {self.format})" + + +@dataclass +class SpriteFrame(Frame): + """An animated sprite. Will be shown by the transport if the transport's + camera is enabled. Will play at the framerate specified in the transport's + `fps` constructor parameter. + + """ + images: List[ImageRawFrame] + + def __str__(self): + return f"{self.name}(size: {len(self.images)})" + + +@dataclass +class TextFrame(DataFrame): + """A chunk of text. Emitted by LLM services, consumed by TTS services, can + be used to send text through pipelines. + + """ + text: str + + def __str__(self): + return f"{self.name}(text: {self.text})" + + +@dataclass +class TranscriptionFrame(TextFrame): + """A text frame with transcription-specific data. Will be placed in the + transport's receive queue when a participant speaks. + + """ + user_id: str + timestamp: str + + def __str__(self): + return f"{self.name}(user: {self.user_id}, text: {self.text}, timestamp: {self.timestamp})" + + +@dataclass +class InterimTranscriptionFrame(TextFrame): + """A text frame with interim transcription-specific data. Will be placed in + the transport's receive queue when a participant speaks.""" + user_id: str + timestamp: str + + def __str__(self): + return f"{self.name}(user: {self.user_id}, text: {self.text}, timestamp: {self.timestamp})" + + +@dataclass +class LLMMessagesFrame(DataFrame): + """A frame containing a list of LLM messages. Used to signal that an LLM + service should run a chat completion and emit an LLMStartFrames, TextFrames + and an LLMEndFrame. Note that the messages property on this class is + mutable, and will be be updated by various ResponseAggregator frame + processors. + + """ + messages: List[dict] + + +@dataclass +class TransportMessageFrame(DataFrame): + message: Any + + def __str__(self): + return f"{self.name}(message: {self.message})" + +# +# App frames. Application user-defined frames. +# + + +@dataclass +class AppFrame(Frame): + pass + +# +# System frames +# + + +@dataclass +class SystemFrame(Frame): + pass + + +@dataclass +class StartFrame(SystemFrame): + """This is the first frame that should be pushed down a pipeline.""" + allow_interruptions: bool = False + enable_metrics: bool = False + report_only_initial_ttfb: bool = False + + +@dataclass +class CancelFrame(SystemFrame): + """Indicates that a pipeline needs to stop right away.""" + pass + + +@dataclass +class ErrorFrame(SystemFrame): + """This is used notify upstream that an error has occurred downstream the + pipeline.""" + error: str | None + + def __str__(self): + return f"{self.name}(error: {self.error})" + + +@dataclass +class StopTaskFrame(SystemFrame): + """Indicates that a pipeline task should be stopped. This should inform the + pipeline processors that they should stop pushing frames but that they + should be kept in a running state. + + """ + pass + + +@dataclass +class StartInterruptionFrame(SystemFrame): + """Emitted by VAD to indicate that a user has started speaking (i.e. is + interruption). This is similar to UserStartedSpeakingFrame except that it + should be pushed concurrently with other frames (so the order is not + guaranteed). 
+ + """ + pass + + +@dataclass +class StopInterruptionFrame(SystemFrame): + """Emitted by VAD to indicate that a user has stopped speaking (i.e. no more + interruptions). This is similar to UserStoppedSpeakingFrame except that it + should be pushed concurrently with other frames (so the order is not + guaranteed). + + """ + pass + + +@dataclass +class MetricsFrame(SystemFrame): + """Emitted by processor that can compute metrics like latencies. + """ + ttfb: List[Mapping[str, Any]] | None = None + processing: List[Mapping[str, Any]] | None = None + +# +# Control frames +# + + +@dataclass +class ControlFrame(Frame): + pass + + +@dataclass +class EndFrame(ControlFrame): + """Indicates that a pipeline has ended and frame processors and pipelines + should be shut down. If the transport receives this frame, it will stop + sending frames to its output channel(s) and close all its threads. Note, + that this is a control frame, which means it will received in the order it + was sent (unline system frames). + + """ + pass + + +@dataclass +class LLMFullResponseStartFrame(ControlFrame): + """Used to indicate the beginning of a full LLM response. Following + LLMResponseStartFrame, TextFrame and LLMResponseEndFrame for each sentence + until a LLMFullResponseEndFrame.""" + pass + + +@dataclass +class LLMFullResponseEndFrame(ControlFrame): + """Indicates the end of a full LLM response.""" + pass + + +@dataclass +class LLMResponseStartFrame(ControlFrame): + """Used to indicate the beginning of an LLM response. Following TextFrames + are part of the LLM response until an LLMResponseEndFrame""" + pass + + +@dataclass +class LLMResponseEndFrame(ControlFrame): + """Indicates the end of an LLM response.""" + pass + + +@dataclass +class UserStartedSpeakingFrame(ControlFrame): + """Emitted by VAD to indicate that a user has started speaking. This can be + used for interruptions or other times when detecting that someone is + speaking is more important than knowing what they're saying (as you will + with a TranscriptionFrame) + + """ + pass + + +@dataclass +class UserStoppedSpeakingFrame(ControlFrame): + """Emitted by the VAD to indicate that a user stopped speaking.""" + pass + + +@dataclass +class TTSStartedFrame(ControlFrame): + """Used to indicate the beginning of a TTS response. Following + AudioRawFrames are part of the TTS response until an TTSEndFrame. These + frames can be used for aggregating audio frames in a transport to optimize + the size of frames sent to the session, without needing to control this in + the TTS service. + + """ + pass + + +@dataclass +class TTSStoppedFrame(ControlFrame): + """Indicates the end of a TTS response.""" + pass + + +@dataclass +class UserImageRequestFrame(ControlFrame): + """A frame user to request an image from the given user.""" + user_id: str + + def __str__(self): + return f"{self.name}, user: {self.user_id}" diff --git a/pipecat/frames/protobufs/frames_pb2.py b/pipecat/frames/protobufs/frames_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..6341339ec4994831a0d278de4611cbb5c1186c14 --- /dev/null +++ b/pipecat/frames/protobufs/frames_pb2.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: frames.proto +# Protobuf Python Version: 4.25.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0c\x66rames.proto\x12\x07pipecat\"3\n\tTextFrame\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04text\x18\x03 \x01(\t\"c\n\rAudioRawFrame\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\r\n\x05\x61udio\x18\x03 \x01(\x0c\x12\x13\n\x0bsample_rate\x18\x04 \x01(\r\x12\x14\n\x0cnum_channels\x18\x05 \x01(\r\"`\n\x12TranscriptionFrame\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04text\x18\x03 \x01(\t\x12\x0f\n\x07user_id\x18\x04 \x01(\t\x12\x11\n\ttimestamp\x18\x05 \x01(\t\"\x93\x01\n\x05\x46rame\x12\"\n\x04text\x18\x01 \x01(\x0b\x32\x12.pipecat.TextFrameH\x00\x12\'\n\x05\x61udio\x18\x02 \x01(\x0b\x32\x16.pipecat.AudioRawFrameH\x00\x12\x34\n\rtranscription\x18\x03 \x01(\x0b\x32\x1b.pipecat.TranscriptionFrameH\x00\x42\x07\n\x05\x66rameb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'frames_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals['_TEXTFRAME']._serialized_start=25 + _globals['_TEXTFRAME']._serialized_end=76 + _globals['_AUDIORAWFRAME']._serialized_start=78 + _globals['_AUDIORAWFRAME']._serialized_end=177 + _globals['_TRANSCRIPTIONFRAME']._serialized_start=179 + _globals['_TRANSCRIPTIONFRAME']._serialized_end=275 + _globals['_FRAME']._serialized_start=278 + _globals['_FRAME']._serialized_end=425 +# @@protoc_insertion_point(module_scope) diff --git a/pipecat/pipeline/__init__.py b/pipecat/pipeline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pipecat/pipeline/base_pipeline.py b/pipecat/pipeline/base_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..4a1b62324f77f20742558cb23363aeb3216300a8 --- /dev/null +++ b/pipecat/pipeline/base_pipeline.py @@ -0,0 +1,21 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from abc import abstractmethod + +from typing import List + +from pipecat.processors.frame_processor import FrameProcessor + + +class BasePipeline(FrameProcessor): + + def __init__(self): + super().__init__() + + @abstractmethod + def processors_with_metrics(self) -> List[FrameProcessor]: + pass diff --git a/pipecat/pipeline/merge_pipeline.py b/pipecat/pipeline/merge_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..2478fce430816b53a616ecad6839dd50c4a99526 --- /dev/null +++ b/pipecat/pipeline/merge_pipeline.py @@ -0,0 +1,24 @@ +from typing import List +from pipecat.pipeline.frames import EndFrame, EndPipeFrame +from pipecat.pipeline.pipeline import Pipeline + + +class SequentialMergePipeline(Pipeline): + """This class merges the sink queues from a list of pipelines. 
Frames from + each pipeline's sink are merged in the order of pipelines in the list.""" + + def __init__(self, pipelines: List[Pipeline]): + super().__init__([]) + self.pipelines = pipelines + + async def run_pipeline(self): + for idx, pipeline in enumerate(self.pipelines): + while True: + frame = await pipeline.sink.get() + if isinstance( + frame, EndFrame) or isinstance( + frame, EndPipeFrame): + break + await self.sink.put(frame) + + await self.sink.put(EndFrame()) diff --git a/pipecat/pipeline/parallel_pipeline.py b/pipecat/pipeline/parallel_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..d608cd2b84044b4464ef42f23fe1d02b1f8b7214 --- /dev/null +++ b/pipecat/pipeline/parallel_pipeline.py @@ -0,0 +1,154 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio + +from itertools import chain +from typing import List + +from pipecat.pipeline.base_pipeline import BasePipeline +from pipecat.pipeline.pipeline import Pipeline +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.frames.frames import CancelFrame, EndFrame, Frame, StartFrame + +from loguru import logger + + +class Source(FrameProcessor): + + def __init__(self, upstream_queue: asyncio.Queue): + super().__init__() + self._up_queue = upstream_queue + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + match direction: + case FrameDirection.UPSTREAM: + await self._up_queue.put(frame) + case FrameDirection.DOWNSTREAM: + await self.push_frame(frame, direction) + + +class Sink(FrameProcessor): + + def __init__(self, downstream_queue: asyncio.Queue): + super().__init__() + self._down_queue = downstream_queue + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + match direction: + case FrameDirection.UPSTREAM: + await self.push_frame(frame, direction) + case FrameDirection.DOWNSTREAM: + await self._down_queue.put(frame) + + +class ParallelPipeline(BasePipeline): + def __init__(self, *args): + super().__init__() + + if len(args) == 0: + raise Exception(f"ParallelPipeline needs at least one argument") + + self._sources = [] + self._sinks = [] + + self._up_queue = asyncio.Queue() + self._down_queue = asyncio.Queue() + self._up_task: asyncio.Task | None = None + self._down_task: asyncio.Task | None = None + + self._pipelines = [] + + logger.debug(f"Creating {self} pipelines") + for processors in args: + if not isinstance(processors, list): + raise TypeError(f"ParallelPipeline argument {processors} is not a list") + + # We will add a source before the pipeline and a sink after. 
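+            # The source feeds downstream frames into this sub-pipeline and
+            # diverts upstream frames into the shared up queue; the sink does
+            # the reverse with the shared down queue, so every sub-pipeline
+            # fans back into a single stream.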
+ source = Source(self._up_queue) + sink = Sink(self._down_queue) + self._sources.append(source) + self._sinks.append(sink) + + # Create pipeline + pipeline = Pipeline(processors) + source.link(pipeline) + pipeline.link(sink) + self._pipelines.append(pipeline) + + logger.debug(f"Finished creating {self} pipelines") + + # + # BasePipeline + # + + def processors_with_metrics(self) -> List[FrameProcessor]: + return list(chain.from_iterable(p.processors_with_metrics() for p in self._pipelines)) + + # + # Frame processor + # + + async def cleanup(self): + await asyncio.gather(*[p.cleanup() for p in self._pipelines]) + + async def _start_tasks(self): + loop = self.get_event_loop() + self._up_task = loop.create_task(self._process_up_queue()) + self._down_task = loop.create_task(self._process_down_queue()) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if isinstance(frame, StartFrame): + await self._start_tasks() + + if direction == FrameDirection.UPSTREAM: + # If we get an upstream frame we process it in each sink. + await asyncio.gather(*[s.process_frame(frame, direction) for s in self._sinks]) + elif direction == FrameDirection.DOWNSTREAM: + # If we get a downstream frame we process it in each source. + # TODO(aleix): We are creating task for each frame. For real-time + # video/audio this might be too slow. We should use an already + # created task instead. + await asyncio.gather(*[s.process_frame(frame, direction) for s in self._sources]) + + # If we get an EndFrame we stop our queue processing tasks and wait on + # all the pipelines to finish. + if isinstance(frame, CancelFrame) or isinstance(frame, EndFrame): + # Use None to indicate when queues should be done processing. 
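+            # The queue-processing tasks below treat None as a sentinel and
+            # leave their loops once they have handled it.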
+ await self._up_queue.put(None) + await self._down_queue.put(None) + if self._up_task: + await self._up_task + if self._down_task: + await self._down_task + + async def _process_up_queue(self): + running = True + seen_ids = set() + while running: + frame = await self._up_queue.get() + if frame and frame.id not in seen_ids: + await self.push_frame(frame, FrameDirection.UPSTREAM) + seen_ids.add(frame.id) + running = frame is not None + self._up_queue.task_done() + + async def _process_down_queue(self): + running = True + seen_ids = set() + while running: + frame = await self._down_queue.get() + if frame and frame.id not in seen_ids: + await self.push_frame(frame, FrameDirection.DOWNSTREAM) + seen_ids.add(frame.id) + running = frame is not None + self._down_queue.task_done() diff --git a/pipecat/pipeline/parallel_task.py b/pipecat/pipeline/parallel_task.py new file mode 100644 index 0000000000000000000000000000000000000000..f6070875613156e2613148eba27f1150bcec23a3 --- /dev/null +++ b/pipecat/pipeline/parallel_task.py @@ -0,0 +1,119 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio + +from itertools import chain +from typing import List + +from pipecat.pipeline.base_pipeline import BasePipeline +from pipecat.pipeline.pipeline import Pipeline +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.frames.frames import Frame + +from loguru import logger + + +class Source(FrameProcessor): + + def __init__(self, upstream_queue: asyncio.Queue): + super().__init__() + self._up_queue = upstream_queue + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + match direction: + case FrameDirection.UPSTREAM: + await self._up_queue.put(frame) + case FrameDirection.DOWNSTREAM: + await self.push_frame(frame, direction) + + +class Sink(FrameProcessor): + + def __init__(self, downstream_queue: asyncio.Queue): + super().__init__() + self._down_queue = downstream_queue + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + match direction: + case FrameDirection.UPSTREAM: + await self.push_frame(frame, direction) + case FrameDirection.DOWNSTREAM: + await self._down_queue.put(frame) + + +class ParallelTask(BasePipeline): + def __init__(self, *args): + super().__init__() + + if len(args) == 0: + raise Exception(f"ParallelTask needs at least one argument") + + self._sinks = [] + self._pipelines = [] + + self._up_queue = asyncio.Queue() + self._down_queue = asyncio.Queue() + + logger.debug(f"Creating {self} pipelines") + for processors in args: + if not isinstance(processors, list): + raise TypeError(f"ParallelTask argument {processors} is not a list") + + # We add a source at the beginning of the pipeline and a sink at the end. + source = Source(self._up_queue) + sink = Sink(self._down_queue) + processors: List[FrameProcessor] = [source] + processors + processors.append(sink) + + # Keep track of sinks. We access the source through the pipeline. 
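+            # (The source was prepended above, so it ends up as the first
+            # processor of the pipeline created below and is reached through
+            # the pipeline's own source.)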
+ self._sinks.append(sink) + + # Create pipeline + pipeline = Pipeline(processors) + self._pipelines.append(pipeline) + logger.debug(f"Finished creating {self} pipelines") + + # + # BasePipeline + # + + def processors_with_metrics(self) -> List[FrameProcessor]: + return list(chain.from_iterable(p.processors_with_metrics() for p in self._pipelines)) + + # + # Frame processor + # + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if direction == FrameDirection.UPSTREAM: + # If we get an upstream frame we process it in each sink. + await asyncio.gather(*[s.process_frame(frame, direction) for s in self._sinks]) + elif direction == FrameDirection.DOWNSTREAM: + # If we get a downstream frame we process it in each source (using the pipeline). + await asyncio.gather(*[p.process_frame(frame, direction) for p in self._pipelines]) + + seen_ids = set() + while not self._up_queue.empty(): + frame = await self._up_queue.get() + if frame and frame.id not in seen_ids: + await self.push_frame(frame, FrameDirection.UPSTREAM) + seen_ids.add(frame.id) + self._up_queue.task_done() + + seen_ids = set() + while not self._down_queue.empty(): + frame = await self._down_queue.get() + if frame and frame.id not in seen_ids: + await self.push_frame(frame, FrameDirection.DOWNSTREAM) + seen_ids.add(frame.id) + self._down_queue.task_done() diff --git a/pipecat/pipeline/pipeline.py b/pipecat/pipeline/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..88dbbfd6196fc3fee021dbbb099598ad42758d19 --- /dev/null +++ b/pipecat/pipeline/pipeline.py @@ -0,0 +1,95 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from typing import Callable, Coroutine, List + +from pipecat.frames.frames import Frame +from pipecat.pipeline.base_pipeline import BasePipeline +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor + + +class PipelineSource(FrameProcessor): + + def __init__(self, upstream_push_frame: Callable[[Frame, FrameDirection], Coroutine]): + super().__init__() + self._upstream_push_frame = upstream_push_frame + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + match direction: + case FrameDirection.UPSTREAM: + await self._upstream_push_frame(frame, direction) + case FrameDirection.DOWNSTREAM: + await self.push_frame(frame, direction) + + +class PipelineSink(FrameProcessor): + + def __init__(self, downstream_push_frame: Callable[[Frame, FrameDirection], Coroutine]): + super().__init__() + self._downstream_push_frame = downstream_push_frame + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + match direction: + case FrameDirection.UPSTREAM: + await self.push_frame(frame, direction) + case FrameDirection.DOWNSTREAM: + await self._downstream_push_frame(frame, direction) + + +class Pipeline(BasePipeline): + + def __init__(self, processors: List[FrameProcessor]): + super().__init__() + + # Add a source and a sink queue so we can forward frames upstream and + # downstream outside of the pipeline. 
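+        # (They are frame processors rather than queues: the source relays
+        # upstream frames out of the pipeline and the sink relays downstream
+        # frames out of it, both through our own push_frame.)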
+ self._source = PipelineSource(self.push_frame) + self._sink = PipelineSink(self.push_frame) + self._processors: List[FrameProcessor] = [self._source] + processors + [self._sink] + + self._link_processors() + + # + # BasePipeline + # + + def processors_with_metrics(self): + services = [] + for p in self._processors: + if isinstance(p, BasePipeline): + services += p.processors_with_metrics() + elif p.can_generate_metrics(): + services.append(p) + return services + + # + # Frame processor + # + + async def cleanup(self): + await self._cleanup_processors() + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if direction == FrameDirection.DOWNSTREAM: + await self._source.process_frame(frame, FrameDirection.DOWNSTREAM) + elif direction == FrameDirection.UPSTREAM: + await self._sink.process_frame(frame, FrameDirection.UPSTREAM) + + async def _cleanup_processors(self): + for p in self._processors: + await p.cleanup() + + def _link_processors(self): + prev = self._processors[0] + for curr in self._processors[1:]: + prev.link(curr) + prev = curr diff --git a/pipecat/pipeline/runner.py b/pipecat/pipeline/runner.py new file mode 100644 index 0000000000000000000000000000000000000000..6d0003a33b7c47af7f969733ca132e1b1c6e4af6 --- /dev/null +++ b/pipecat/pipeline/runner.py @@ -0,0 +1,58 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +import signal + +from pipecat.pipeline.task import PipelineTask +from pipecat.utils.utils import obj_count, obj_id + +from loguru import logger + + +class PipelineRunner: + + def __init__(self, *, name: str | None = None, handle_sigint: bool = True): + self.id: int = obj_id() + self.name: str = name or f"{self.__class__.__name__}#{obj_count(self)}" + + self._tasks = {} + + if handle_sigint: + self._setup_sigint() + + async def run(self, task: PipelineTask): + logger.debug(f"Runner {self} started running {task}") + self._tasks[task.name] = task + await task.run() + del self._tasks[task.name] + logger.debug(f"Runner {self} finished running {task}") + + async def stop_when_done(self): + logger.debug(f"Runner {self} scheduled to stop when all tasks are done") + await asyncio.gather(*[t.stop_when_done() for t in self._tasks.values()]) + + async def cancel(self): + logger.debug(f"Canceling runner {self}") + await asyncio.gather(*[t.cancel() for t in self._tasks.values()]) + + def _setup_sigint(self): + loop = asyncio.get_running_loop() + loop.add_signal_handler( + signal.SIGINT, + lambda *args: asyncio.create_task(self._sig_handler()) + ) + loop.add_signal_handler( + signal.SIGTERM, + lambda *args: asyncio.create_task(self._sig_handler()) + ) + + async def _sig_handler(self): + logger.warning(f"Interruption detected. 
Canceling runner {self}") + await self.cancel() + + def __str__(self): + return self.name diff --git a/pipecat/pipeline/task.py b/pipecat/pipeline/task.py new file mode 100644 index 0000000000000000000000000000000000000000..1fa9cff4993b21f999e284c17e783a36c886702c --- /dev/null +++ b/pipecat/pipeline/task.py @@ -0,0 +1,142 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio + +from typing import AsyncIterable, Iterable + +from pydantic import BaseModel + +from pipecat.frames.frames import CancelFrame, EndFrame, ErrorFrame, Frame, MetricsFrame, StartFrame, StopTaskFrame +from pipecat.pipeline.base_pipeline import BasePipeline +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.utils.utils import obj_count, obj_id + +from loguru import logger + + +class PipelineParams(BaseModel): + allow_interruptions: bool = False + enable_metrics: bool = False + report_only_initial_ttfb: bool = False + + +class Source(FrameProcessor): + + def __init__(self, up_queue: asyncio.Queue): + super().__init__() + self._up_queue = up_queue + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + match direction: + case FrameDirection.UPSTREAM: + await self._up_queue.put(frame) + case FrameDirection.DOWNSTREAM: + await self.push_frame(frame, direction) + + +class PipelineTask: + + def __init__(self, pipeline: BasePipeline, params: PipelineParams = PipelineParams()): + self.id: int = obj_id() + self.name: str = f"{self.__class__.__name__}#{obj_count(self)}" + + self._pipeline = pipeline + self._params = params + self._finished = False + + self._down_queue = asyncio.Queue() + self._up_queue = asyncio.Queue() + + self._source = Source(self._up_queue) + self._source.link(pipeline) + + def has_finished(self): + return self._finished + + async def stop_when_done(self): + logger.debug(f"Task {self} scheduled to stop when done") + await self.queue_frame(EndFrame()) + + async def cancel(self): + logger.debug(f"Canceling pipeline task {self}") + # Make sure everything is cleaned up downstream. This is sent + # out-of-band from the main streaming task which is what we want since + # we want to cancel right away. 
+ await self._source.process_frame(CancelFrame(), FrameDirection.DOWNSTREAM) + self._process_down_task.cancel() + self._process_up_task.cancel() + await self._process_down_task + await self._process_up_task + + async def run(self): + self._process_up_task = asyncio.create_task(self._process_up_queue()) + self._process_down_task = asyncio.create_task(self._process_down_queue()) + await asyncio.gather(self._process_up_task, self._process_down_task) + self._finished = True + + async def queue_frame(self, frame: Frame): + await self._down_queue.put(frame) + + async def queue_frames(self, frames: Iterable[Frame] | AsyncIterable[Frame]): + if isinstance(frames, AsyncIterable): + async for frame in frames: + await self.queue_frame(frame) + elif isinstance(frames, Iterable): + for frame in frames: + await self.queue_frame(frame) + else: + raise Exception("Frames must be an iterable or async iterable") + + def _initial_metrics_frame(self) -> MetricsFrame: + processors = self._pipeline.processors_with_metrics() + ttfb = [{"name": p.name, "time": 0.0} for p in processors] + processing = [{"name": p.name, "time": 0.0} for p in processors] + return MetricsFrame(ttfb=ttfb, processing=processing) + + async def _process_down_queue(self): + start_frame = StartFrame( + allow_interruptions=self._params.allow_interruptions, + enable_metrics=self._params.enable_metrics, + report_only_initial_ttfb=self._params.report_only_initial_ttfb + ) + await self._source.process_frame(start_frame, FrameDirection.DOWNSTREAM) + await self._source.process_frame(self._initial_metrics_frame(), FrameDirection.DOWNSTREAM) + + running = True + should_cleanup = True + while running: + try: + frame = await self._down_queue.get() + await self._source.process_frame(frame, FrameDirection.DOWNSTREAM) + running = not (isinstance(frame, StopTaskFrame) or isinstance(frame, EndFrame)) + should_cleanup = not isinstance(frame, StopTaskFrame) + self._down_queue.task_done() + except asyncio.CancelledError: + break + # Cleanup only if we need to. + if should_cleanup: + await self._source.cleanup() + await self._pipeline.cleanup() + # We just enqueue None to terminate the task gracefully. 
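+        # (Note that this version cancels the up-queue task directly;
+        # _process_up_queue exits through its CancelledError handler.)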
+ self._process_up_task.cancel() + await self._process_up_task + + async def _process_up_queue(self): + while True: + try: + frame = await self._up_queue.get() + if isinstance(frame, ErrorFrame): + logger.error(f"Error running app: {frame.error}") + await self.queue_frame(CancelFrame()) + self._up_queue.task_done() + except asyncio.CancelledError: + break + + def __str__(self): + return self.name diff --git a/pipecat/processors/__init__.py b/pipecat/processors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pipecat/processors/aggregators/__init__.py b/pipecat/processors/aggregators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pipecat/processors/aggregators/gated.py b/pipecat/processors/aggregators/gated.py new file mode 100644 index 0000000000000000000000000000000000000000..14398311073f5d1c81a7d93c91de410c1c7ed277 --- /dev/null +++ b/pipecat/processors/aggregators/gated.py @@ -0,0 +1,74 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from typing import List + +from pipecat.frames.frames import Frame, SystemFrame +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor + +from loguru import logger + + +class GatedAggregator(FrameProcessor): + """Accumulate frames, with custom functions to start and stop accumulation. + Yields gate-opening frame before any accumulated frames, then ensuing frames + until and not including the gate-closed frame. + + >>> from pipecat.pipeline.frames import ImageFrame + + >>> async def print_frames(aggregator, frame): + ... async for frame in aggregator.process_frame(frame): + ... if isinstance(frame, TextFrame): + ... print(frame.text) + ... else: + ... print(frame.__class__.__name__) + + >>> aggregator = GatedAggregator( + ... gate_close_fn=lambda x: isinstance(x, LLMResponseStartFrame), + ... gate_open_fn=lambda x: isinstance(x, ImageFrame), + ... start_open=False) + >>> asyncio.run(print_frames(aggregator, TextFrame("Hello"))) + >>> asyncio.run(print_frames(aggregator, TextFrame("Hello again."))) + >>> asyncio.run(print_frames(aggregator, ImageFrame(image=bytes([]), size=(0, 0)))) + ImageFrame + Hello + Hello again. + >>> asyncio.run(print_frames(aggregator, TextFrame("Goodbye."))) + Goodbye. + """ + + def __init__(self, gate_open_fn, gate_close_fn, start_open): + super().__init__() + self._gate_open_fn = gate_open_fn + self._gate_close_fn = gate_close_fn + self._gate_open = start_open + self._accumulator: List[Frame] = [] + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + # We must not block system frames. 
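+        # System frames (StartFrame, CancelFrame, etc.) always bypass the gate
+        # so pipeline control is never held back by the accumulator.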
+ if isinstance(frame, SystemFrame): + await self.push_frame(frame, direction) + return + + old_state = self._gate_open + if self._gate_open: + self._gate_open = not self._gate_close_fn(frame) + else: + self._gate_open = self._gate_open_fn(frame) + + if old_state != self._gate_open: + state = "open" if self._gate_open else "closed" + logger.debug(f"Gate is now {state} because of {frame}") + + if self._gate_open: + await self.push_frame(frame, direction) + for frame in self._accumulator: + await self.push_frame(frame, direction) + self._accumulator = [] + else: + self._accumulator.append(frame) diff --git a/pipecat/processors/aggregators/llm_response.py b/pipecat/processors/aggregators/llm_response.py new file mode 100644 index 0000000000000000000000000000000000000000..1f8386566affb7f14194b68052be0eb296bc75e2 --- /dev/null +++ b/pipecat/processors/aggregators/llm_response.py @@ -0,0 +1,266 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from typing import List + +from pipecat.services.openai import OpenAILLMContextFrame, OpenAILLMContext + +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.frames.frames import ( + Frame, + InterimTranscriptionFrame, + LLMFullResponseEndFrame, + LLMFullResponseStartFrame, + LLMResponseEndFrame, + LLMResponseStartFrame, + LLMMessagesFrame, + StartInterruptionFrame, + TranscriptionFrame, + TextFrame, + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame) + + +class LLMResponseAggregator(FrameProcessor): + + def __init__( + self, + *, + messages: List[dict], + role: str, + start_frame, + end_frame, + accumulator_frame: TextFrame, + interim_accumulator_frame: TextFrame | None = None, + handle_interruptions: bool = False + ): + super().__init__() + + self._messages = messages + self._role = role + self._start_frame = start_frame + self._end_frame = end_frame + self._accumulator_frame = accumulator_frame + self._interim_accumulator_frame = interim_accumulator_frame + self._handle_interruptions = handle_interruptions + + # Reset our accumulator state. + self._reset() + + @property + def messages(self): + return self._messages + + @property + def role(self): + return self._role + + # + # Frame processor + # + + # Use cases implemented: + # + # S: Start, E: End, T: Transcription, I: Interim, X: Text + # + # S E -> None + # S T E -> X + # S I T E -> X + # S I E T -> X + # S I E I T -> X + # S E T -> X + # S E I T -> X + # + # The following case would not be supported: + # + # S I E T1 I T2 -> X + # + # and T2 would be dropped. + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + send_aggregation = False + + if isinstance(frame, self._start_frame): + self._aggregation = "" + self._aggregating = True + self._seen_start_frame = True + self._seen_end_frame = False + self._seen_interim_results = False + await self.push_frame(frame, direction) + elif isinstance(frame, self._end_frame): + self._seen_end_frame = True + self._seen_start_frame = False + + # We might have received the end frame but we might still be + # aggregating (i.e. we have seen interim results but not the final + # text). + self._aggregating = self._seen_interim_results or len(self._aggregation) == 0 + + # Send the aggregation if we are not aggregating anymore (i.e. no + # more interim results received). 
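+            # This covers the cases in the table above where the final
+            # transcription arrives after the end frame (e.g. "S E T").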
+ send_aggregation = not self._aggregating + await self.push_frame(frame, direction) + elif isinstance(frame, self._accumulator_frame): + if self._aggregating: + self._aggregation += f" {frame.text}" + # We have recevied a complete sentence, so if we have seen the + # end frame and we were still aggregating, it means we should + # send the aggregation. + send_aggregation = self._seen_end_frame + + # We just got our final result, so let's reset interim results. + self._seen_interim_results = False + elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame): + self._seen_interim_results = True + elif self._handle_interruptions and isinstance(frame, StartInterruptionFrame): + await self._push_aggregation() + # Reset anyways + self._reset() + await self.push_frame(frame, direction) + else: + await self.push_frame(frame, direction) + + if send_aggregation: + await self._push_aggregation() + + async def _push_aggregation(self): + if len(self._aggregation) > 0: + self._messages.append({"role": self._role, "content": self._aggregation}) + + # Reset the aggregation. Reset it before pushing it down, otherwise + # if the tasks gets cancelled we won't be able to clear things up. + self._aggregation = "" + + frame = LLMMessagesFrame(self._messages) + await self.push_frame(frame) + + def _reset(self): + self._aggregation = "" + self._aggregating = False + self._seen_start_frame = False + self._seen_end_frame = False + self._seen_interim_results = False + + +class LLMAssistantResponseAggregator(LLMResponseAggregator): + def __init__(self, messages: List[dict] = []): + super().__init__( + messages=messages, + role="assistant", + start_frame=LLMFullResponseStartFrame, + end_frame=LLMFullResponseEndFrame, + accumulator_frame=TextFrame, + handle_interruptions=True + ) + + +class LLMUserResponseAggregator(LLMResponseAggregator): + def __init__(self, messages: List[dict] = []): + super().__init__( + messages=messages, + role="user", + start_frame=UserStartedSpeakingFrame, + end_frame=UserStoppedSpeakingFrame, + accumulator_frame=TranscriptionFrame, + interim_accumulator_frame=InterimTranscriptionFrame + ) + + +class LLMFullResponseAggregator(FrameProcessor): + """This class aggregates Text frames until it receives a + LLMResponseEndFrame, then emits the concatenated text as + a single text frame. + + given the following frames: + + TextFrame("Hello,") + TextFrame(" world.") + TextFrame(" I am") + TextFrame(" an LLM.") + LLMResponseEndFrame()] + + this processor will yield nothing for the first 4 frames, then + + TextFrame("Hello, world. I am an LLM.") + LLMResponseEndFrame() + + when passed the last frame. + + >>> async def print_frames(aggregator, frame): + ... async for frame in aggregator.process_frame(frame): + ... if isinstance(frame, TextFrame): + ... print(frame.text) + ... else: + ... print(frame.__class__.__name__) + + >>> aggregator = LLMFullResponseAggregator() + >>> asyncio.run(print_frames(aggregator, TextFrame("Hello,"))) + >>> asyncio.run(print_frames(aggregator, TextFrame(" world."))) + >>> asyncio.run(print_frames(aggregator, TextFrame(" I am"))) + >>> asyncio.run(print_frames(aggregator, TextFrame(" an LLM."))) + >>> asyncio.run(print_frames(aggregator, LLMResponseEndFrame())) + Hello, world. I am an LLM. 
+ LLMResponseEndFrame + """ + + def __init__(self): + super().__init__() + self._aggregation = "" + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if isinstance(frame, TextFrame): + self._aggregation += frame.text + elif isinstance(frame, LLMFullResponseEndFrame): + await self.push_frame(TextFrame(self._aggregation)) + await self.push_frame(frame) + self._aggregation = "" + else: + await self.push_frame(frame, direction) + + +class LLMContextAggregator(LLMResponseAggregator): + def __init__(self, *, context: OpenAILLMContext, **kwargs): + + self._context = context + super().__init__(**kwargs) + + async def _push_aggregation(self): + if len(self._aggregation) > 0: + self._context.add_message({"role": self._role, "content": self._aggregation}) + frame = OpenAILLMContextFrame(self._context) + await self.push_frame(frame) + + # Reset our accumulator state. + self._reset() + + +class LLMAssistantContextAggregator(LLMContextAggregator): + def __init__(self, context: OpenAILLMContext): + super().__init__( + messages=[], + context=context, + role="assistant", + start_frame=LLMResponseStartFrame, + end_frame=LLMResponseEndFrame, + accumulator_frame=TextFrame + ) + + +class LLMUserContextAggregator(LLMContextAggregator): + def __init__(self, context: OpenAILLMContext): + super().__init__( + messages=[], + context=context, + role="user", + start_frame=UserStartedSpeakingFrame, + end_frame=UserStoppedSpeakingFrame, + accumulator_frame=TranscriptionFrame, + interim_accumulator_frame=InterimTranscriptionFrame + ) diff --git a/pipecat/processors/aggregators/openai_llm_context.py b/pipecat/processors/aggregators/openai_llm_context.py new file mode 100644 index 0000000000000000000000000000000000000000..e23923bc6e22b1a865ecf5eed87070fda7fc2b5c --- /dev/null +++ b/pipecat/processors/aggregators/openai_llm_context.py @@ -0,0 +1,114 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from dataclasses import dataclass +import io +import json + +from typing import List + +from PIL import Image + +from pipecat.frames.frames import Frame, VisionImageRawFrame + +from openai._types import NOT_GIVEN, NotGiven + +from openai.types.chat import ( + ChatCompletionToolParam, + ChatCompletionToolChoiceOptionParam, + ChatCompletionMessageParam +) + +# JSON custom encoder to handle bytes arrays so that we can log contexts +# with images to the console. 
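+# Only the first 8 bytes of each BytesIO value are rendered (as a hex string),
+# so logged contexts stay readable even when they carry JPEG image data.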
+ + +class CustomEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, io.BytesIO): + # Convert the first 8 bytes to an ASCII hex string + return (f"{obj.getbuffer()[0:8].hex()}...") + return super().default(obj) + + +class OpenAILLMContext: + + def __init__( + self, + messages: List[ChatCompletionMessageParam] | None = None, + tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN, + tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN + ): + self.messages: List[ChatCompletionMessageParam] = messages if messages else [ + ] + self.tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = tool_choice + self.tools: List[ChatCompletionToolParam] | NotGiven = tools + + @staticmethod + def from_messages(messages: List[dict]) -> "OpenAILLMContext": + context = OpenAILLMContext() + for message in messages: + context.add_message({ + "content": message["content"], + "role": message["role"], + "name": message["name"] if "name" in message else message["role"] + }) + return context + + @staticmethod + def from_image_frame(frame: VisionImageRawFrame) -> "OpenAILLMContext": + """ + For images, we are deviating from the OpenAI messages shape. OpenAI + expects images to be base64 encoded, but other vision models may not. + So we'll store the image as bytes and do the base64 encoding as needed + in the LLM service. + """ + context = OpenAILLMContext() + buffer = io.BytesIO() + Image.frombytes( + frame.format, + frame.size, + frame.image + ).save( + buffer, + format="JPEG") + context.add_message({ + "content": frame.text, + "role": "user", + "data": buffer, + "mime_type": "image/jpeg" + }) + return context + + def add_message(self, message: ChatCompletionMessageParam): + self.messages.append(message) + + def get_messages(self) -> List[ChatCompletionMessageParam]: + return self.messages + + def get_messages_json(self) -> str: + return json.dumps(self.messages, cls=CustomEncoder) + + def set_tool_choice( + self, tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven + ): + self.tool_choice = tool_choice + + def set_tools(self, tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN): + if tools != NOT_GIVEN and len(tools) == 0: + tools = NOT_GIVEN + + self.tools = tools + + +@dataclass +class OpenAILLMContextFrame(Frame): + """Like an LLMMessagesFrame, but with extra context specific to the OpenAI + API. The context in this message is also mutable, and will be changed by the + OpenAIContextAggregator frame processor. + + """ + context: OpenAILLMContext diff --git a/pipecat/processors/aggregators/sentence.py b/pipecat/processors/aggregators/sentence.py new file mode 100644 index 0000000000000000000000000000000000000000..5a3bab7b8974adfd35cbd80c12b4590cee245d3f --- /dev/null +++ b/pipecat/processors/aggregators/sentence.py @@ -0,0 +1,54 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import re + +from pipecat.frames.frames import EndFrame, Frame, InterimTranscriptionFrame, TextFrame +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor + + +class SentenceAggregator(FrameProcessor): + """This frame processor aggregates text frames into complete sentences. + + Frame input/output: + TextFrame("Hello,") -> None + TextFrame(" world.") -> TextFrame("Hello world.") + + Doctest: + >>> async def print_frames(aggregator, frame): + ... async for frame in aggregator.process_frame(frame): + ... 
print(frame.text) + + >>> aggregator = SentenceAggregator() + >>> asyncio.run(print_frames(aggregator, TextFrame("Hello,"))) + >>> asyncio.run(print_frames(aggregator, TextFrame(" world."))) + Hello, world. + """ + + def __init__(self): + super().__init__() + self._aggregation = "" + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + # We ignore interim description at this point. + if isinstance(frame, InterimTranscriptionFrame): + return + + if isinstance(frame, TextFrame): + m = re.search("(.*[?.!])(.*)", frame.text) + if m: + await self.push_frame(TextFrame(self._aggregation + m.group(1))) + self._aggregation = m.group(2) + else: + self._aggregation += frame.text + elif isinstance(frame, EndFrame): + if self._aggregation: + await self.push_frame(TextFrame(self._aggregation)) + await self.push_frame(frame) + else: + await self.push_frame(frame, direction) diff --git a/pipecat/processors/aggregators/user_response.py b/pipecat/processors/aggregators/user_response.py new file mode 100644 index 0000000000000000000000000000000000000000..900ab6d93a6ca045f08688eeff9a40051d221753 --- /dev/null +++ b/pipecat/processors/aggregators/user_response.py @@ -0,0 +1,156 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.frames.frames import ( + Frame, + InterimTranscriptionFrame, + StartInterruptionFrame, + TextFrame, + TranscriptionFrame, + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame) + + +class ResponseAggregator(FrameProcessor): + """This frame processor aggregates frames between a start and an end frame + into complete text frame sentences. + + For example, frame input/output: + UserStartedSpeakingFrame() -> None + TranscriptionFrame("Hello,") -> None + TranscriptionFrame(" world.") -> None + UserStoppedSpeakingFrame() -> TextFrame("Hello world.") + + Doctest: + >>> async def print_frames(aggregator, frame): + ... async for frame in aggregator.process_frame(frame): + ... if isinstance(frame, TextFrame): + ... print(frame.text) + + >>> aggregator = ResponseAggregator(start_frame = UserStartedSpeakingFrame, + ... end_frame=UserStoppedSpeakingFrame, + ... accumulator_frame=TranscriptionFrame, + ... pass_through=False) + >>> asyncio.run(print_frames(aggregator, UserStartedSpeakingFrame())) + >>> asyncio.run(print_frames(aggregator, TranscriptionFrame("Hello,", 1, 1))) + >>> asyncio.run(print_frames(aggregator, TranscriptionFrame("world.", 1, 2))) + >>> asyncio.run(print_frames(aggregator, UserStoppedSpeakingFrame())) + Hello, world. + + """ + + def __init__( + self, + *, + start_frame, + end_frame, + accumulator_frame: TextFrame, + interim_accumulator_frame: TextFrame | None = None + ): + super().__init__() + + self._start_frame = start_frame + self._end_frame = end_frame + self._accumulator_frame = accumulator_frame + self._interim_accumulator_frame = interim_accumulator_frame + + # Reset our accumulator state. + self._reset() + + # + # Frame processor + # + + # Use cases implemented: + # + # S: Start, E: End, T: Transcription, I: Interim, X: Text + # + # S E -> None + # S T E -> X + # S I T E -> X + # S I E T -> X + # S I E I T -> X + # S E T -> X + # S E I T -> X + # + # The following case would not be supported: + # + # S I E T1 I T2 -> X + # + # and T2 would be dropped. 
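+    # (This is the same start/end/interim state machine used by
+    # LLMResponseAggregator; the difference is that this class pushes a plain
+    # TextFrame once an aggregation is ready.)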
+ + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + send_aggregation = False + + if isinstance(frame, self._start_frame): + self._aggregating = True + self._seen_start_frame = True + self._seen_end_frame = False + self._seen_interim_results = False + await self.push_frame(frame, direction) + elif isinstance(frame, self._end_frame): + self._seen_end_frame = True + self._seen_start_frame = False + + # We might have received the end frame but we might still be + # aggregating (i.e. we have seen interim results but not the final + # text). + self._aggregating = self._seen_interim_results or len(self._aggregation) == 0 + + # Send the aggregation if we are not aggregating anymore (i.e. no + # more interim results received). + send_aggregation = not self._aggregating + await self.push_frame(frame, direction) + elif isinstance(frame, self._accumulator_frame): + if self._aggregating: + self._aggregation += f" {frame.text}" + # We have recevied a complete sentence, so if we have seen the + # end frame and we were still aggregating, it means we should + # send the aggregation. + send_aggregation = self._seen_end_frame + + # We just got our final result, so let's reset interim results. + self._seen_interim_results = False + elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame): + self._seen_interim_results = True + else: + await self.push_frame(frame, direction) + + if send_aggregation: + await self._push_aggregation() + + async def _push_aggregation(self): + if len(self._aggregation) > 0: + frame = TextFrame(self._aggregation.strip()) + + # Reset the aggregation. Reset it before pushing it down, otherwise + # if the tasks gets cancelled we won't be able to clear things up. + self._aggregation = "" + + await self.push_frame(frame) + + # Reset our accumulator state. + self._reset() + + def _reset(self): + self._aggregation = "" + self._aggregating = False + self._seen_start_frame = False + self._seen_end_frame = False + self._seen_interim_results = False + + +class UserResponseAggregator(ResponseAggregator): + def __init__(self): + super().__init__( + start_frame=UserStartedSpeakingFrame, + end_frame=UserStoppedSpeakingFrame, + accumulator_frame=TranscriptionFrame, + interim_accumulator_frame=InterimTranscriptionFrame, + ) diff --git a/pipecat/processors/aggregators/vision_image_frame.py b/pipecat/processors/aggregators/vision_image_frame.py new file mode 100644 index 0000000000000000000000000000000000000000..698a50cb7c5b9320efbab2941a8191df57a8b895 --- /dev/null +++ b/pipecat/processors/aggregators/vision_image_frame.py @@ -0,0 +1,47 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from pipecat.frames.frames import Frame, ImageRawFrame, TextFrame, VisionImageRawFrame +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor + + +class VisionImageFrameAggregator(FrameProcessor): + """This aggregator waits for a consecutive TextFrame and an + ImageFrame. After the ImageFrame arrives it will output a VisionImageFrame. + + >>> from pipecat.pipeline.frames import ImageFrame + + >>> async def print_frames(aggregator, frame): + ... async for frame in aggregator.process_frame(frame): + ... 
print(frame) + + >>> aggregator = VisionImageFrameAggregator() + >>> asyncio.run(print_frames(aggregator, TextFrame("What do you see?"))) + >>> asyncio.run(print_frames(aggregator, ImageFrame(image=bytes([]), size=(0, 0)))) + VisionImageFrame, text: What do you see?, image size: 0x0, buffer size: 0 B + + """ + + def __init__(self): + super().__init__() + self._describe_text = None + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if isinstance(frame, TextFrame): + self._describe_text = frame.text + elif isinstance(frame, ImageRawFrame): + if self._describe_text: + frame = VisionImageRawFrame( + text=self._describe_text, + image=frame.image, + size=frame.size, + format=frame.format) + await self.push_frame(frame) + self._describe_text = None + else: + await self.push_frame(frame, direction) diff --git a/pipecat/processors/async_frame_processor.py b/pipecat/processors/async_frame_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..b5aabba5771973245062a0a2301fd1a8ed6a15c1 --- /dev/null +++ b/pipecat/processors/async_frame_processor.py @@ -0,0 +1,63 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio + +from pipecat.frames.frames import EndFrame, Frame, StartInterruptionFrame +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor + + +class AsyncFrameProcessor(FrameProcessor): + + def __init__( + self, + *, + name: str | None = None, + loop: asyncio.AbstractEventLoop | None = None, + **kwargs): + super().__init__(name=name, loop=loop, **kwargs) + + self._create_push_task() + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if isinstance(frame, StartInterruptionFrame): + await self._handle_interruptions(frame) + + async def queue_frame( + self, + frame: Frame, + direction: FrameDirection = FrameDirection.DOWNSTREAM): + await self._push_queue.put((frame, direction)) + + async def cleanup(self): + self._push_frame_task.cancel() + await self._push_frame_task + + async def _handle_interruptions(self, frame: Frame): + # Cancel the task. This will stop pushing frames downstream. + self._push_frame_task.cancel() + await self._push_frame_task + # Push an out-of-band frame (i.e. not using the ordered push + # frame task). + await self.push_frame(frame) + # Create a new queue and task. 
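+        # Any frames still waiting in the old queue are dropped, which is the
+        # intent when handling an interruption.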
+ self._create_push_task() + + def _create_push_task(self): + self._push_queue = asyncio.Queue() + self._push_frame_task = self.get_event_loop().create_task(self._push_frame_task_handler()) + + async def _push_frame_task_handler(self): + running = True + while running: + try: + (frame, direction) = await self._push_queue.get() + await self.push_frame(frame, direction) + running = not isinstance(frame, EndFrame) + except asyncio.CancelledError: + break diff --git a/pipecat/processors/filters/__init__.py b/pipecat/processors/filters/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pipecat/processors/filters/frame_filter.py b/pipecat/processors/filters/frame_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..e04eb90f635bdb560bbfc210575e7ad0e385af7d --- /dev/null +++ b/pipecat/processors/filters/frame_filter.py @@ -0,0 +1,36 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from typing import List + +from pipecat.frames.frames import AppFrame, ControlFrame, Frame, SystemFrame +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor + + +class FrameFilter(FrameProcessor): + + def __init__(self, types: List[type]): + super().__init__() + self._types = types + + # + # Frame processor + # + + def _should_passthrough_frame(self, frame): + for t in self._types: + if isinstance(frame, t): + return True + + return (isinstance(frame, AppFrame) + or isinstance(frame, ControlFrame) + or isinstance(frame, SystemFrame)) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if self._should_passthrough_frame(frame): + await self.push_frame(frame, direction) diff --git a/pipecat/processors/filters/function_filter.py b/pipecat/processors/filters/function_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..05133ebec6073d8f4906ca5e8bdd03a3aae4a345 --- /dev/null +++ b/pipecat/processors/filters/function_filter.py @@ -0,0 +1,30 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from typing import Awaitable, Callable + +from pipecat.frames.frames import Frame, SystemFrame +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor + + +class FunctionFilter(FrameProcessor): + + def __init__(self, filter: Callable[[Frame], Awaitable[bool]]): + super().__init__() + self._filter = filter + + # + # Frame processor + # + + def _should_passthrough_frame(self, frame): + return isinstance(frame, SystemFrame) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + passthrough = self._should_passthrough_frame(frame) + allowed = await self._filter(frame) + if passthrough or allowed: + await self.push_frame(frame, direction) diff --git a/pipecat/processors/filters/wake_check_filter.py b/pipecat/processors/filters/wake_check_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..621532132dcb1faf73b4cbde3bb8ef30c34ce7ff --- /dev/null +++ b/pipecat/processors/filters/wake_check_filter.py @@ -0,0 +1,86 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import re +import time + +from enum import Enum + +from pipecat.frames.frames import ErrorFrame, Frame, TranscriptionFrame +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor + +from loguru import logger + + +class 
WakeCheckFilter(FrameProcessor): + """ + This filter looks for wake phrases in the transcription frames and only passes through frames + after a wake phrase has been detected. It also has a keepalive timeout to allow for a brief + period of continued conversation after a wake phrase has been detected. + """ + class WakeState(Enum): + IDLE = 1 + AWAKE = 2 + + class ParticipantState: + def __init__(self, participant_id: str): + self.participant_id = participant_id + self.state = WakeCheckFilter.WakeState.IDLE + self.wake_timer = 0.0 + self.accumulator = "" + + def __init__(self, wake_phrases: list[str], keepalive_timeout: float = 3): + super().__init__() + self._participant_states = {} + self._keepalive_timeout = keepalive_timeout + self._wake_patterns = [] + for name in wake_phrases: + pattern = re.compile(r'\b' + r'\s*'.join(re.escape(word) + for word in name.split()) + r'\b', re.IGNORECASE) + self._wake_patterns.append(pattern) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + try: + if isinstance(frame, TranscriptionFrame): + p = self._participant_states.get(frame.user_id) + if p is None: + p = WakeCheckFilter.ParticipantState(frame.user_id) + self._participant_states[frame.user_id] = p + + # If we have been AWAKE within the last keepalive_timeout seconds, pass + # the frame through + if p.state == WakeCheckFilter.WakeState.AWAKE: + if time.time() - p.wake_timer < self._keepalive_timeout: + logger.debug( + f"Wake phrase keepalive timeout has not expired. Pushing {frame}") + p.wake_timer = time.time() + await self.push_frame(frame) + return + else: + p.state = WakeCheckFilter.WakeState.IDLE + + p.accumulator += frame.text + for pattern in self._wake_patterns: + match = pattern.search(p.accumulator) + if match: + logger.debug(f"Wake phrase triggered: {match.group()}") + # Found the wake word. Discard from the accumulator up to the start of the match + # and modify the frame in place. 
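+                        # For example, with the wake phrase "hey robot", an
+                        # accumulated "um, hey robot what time is it" is
+                        # trimmed to "hey robot what time is it" before being
+                        # pushed downstream.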
+ p.state = WakeCheckFilter.WakeState.AWAKE + p.wake_timer = time.time() + frame.text = p.accumulator[match.start():] + p.accumulator = "" + await self.push_frame(frame) + else: + pass + else: + await self.push_frame(frame, direction) + except Exception as e: + error_msg = f"Error in wake word filter: {e}" + logger.exception(error_msg) + await self.push_error(ErrorFrame(error_msg)) diff --git a/pipecat/processors/frame_processor.py b/pipecat/processors/frame_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..3e53034f46b53d91a189e422b4ae81aecb1cdbd3 --- /dev/null +++ b/pipecat/processors/frame_processor.py @@ -0,0 +1,162 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +import time + +from enum import Enum + +from pipecat.frames.frames import ErrorFrame, Frame, MetricsFrame, StartFrame, StartInterruptionFrame, UserStoppedSpeakingFrame +from pipecat.utils.utils import obj_count, obj_id + +from loguru import logger + + +class FrameDirection(Enum): + DOWNSTREAM = 1 + UPSTREAM = 2 + + +class FrameProcessorMetrics: + def __init__(self, name: str): + self._name = name + self._start_ttfb_time = 0 + self._start_processing_time = 0 + self._should_report_ttfb = True + + async def start_ttfb_metrics(self, report_only_initial_ttfb): + if self._should_report_ttfb: + self._start_ttfb_time = time.time() + self._should_report_ttfb = not report_only_initial_ttfb + + async def stop_ttfb_metrics(self): + if self._start_ttfb_time == 0: + return None + + value = time.time() - self._start_ttfb_time + logger.debug(f"{self._name} TTFB: {value}") + ttfb = { + "processor": self._name, + "value": value + } + self._start_ttfb_time = 0 + return MetricsFrame(ttfb=[ttfb]) + + async def start_processing_metrics(self): + self._start_processing_time = time.time() + + async def stop_processing_metrics(self): + if self._start_processing_time == 0: + return None + + value = time.time() - self._start_processing_time + logger.debug(f"{self._name} processing time: {value}") + processing = { + "processor": self._name, + "value": value + } + self._start_processing_time = 0 + return MetricsFrame(processing=[processing]) + + +class FrameProcessor: + + def __init__( + self, + *, + name: str | None = None, + loop: asyncio.AbstractEventLoop | None = None, + **kwargs): + self.id: int = obj_id() + self.name = name or f"{self.__class__.__name__}#{obj_count(self)}" + self._prev: "FrameProcessor" | None = None + self._next: "FrameProcessor" | None = None + self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_running_loop() + + # Properties + self._allow_interruptions = False + self._enable_metrics = False + self._report_only_initial_ttfb = False + + # Metrics + self._metrics = FrameProcessorMetrics(name=self.name) + + @property + def interruptions_allowed(self): + return self._allow_interruptions + + @property + def metrics_enabled(self): + return self._enable_metrics + + @property + def report_only_initial_ttfb(self): + return self._report_only_initial_ttfb + + def can_generate_metrics(self) -> bool: + return False + + async def start_ttfb_metrics(self): + if self.can_generate_metrics() and self.metrics_enabled: + await self._metrics.start_ttfb_metrics(self._report_only_initial_ttfb) + + async def stop_ttfb_metrics(self): + if self.can_generate_metrics() and self.metrics_enabled: + frame = await self._metrics.stop_ttfb_metrics() + if frame: + await self.push_frame(frame) + + async def start_processing_metrics(self): + if 
self.can_generate_metrics() and self.metrics_enabled: + await self._metrics.start_processing_metrics() + + async def stop_processing_metrics(self): + if self.can_generate_metrics() and self.metrics_enabled: + frame = await self._metrics.stop_processing_metrics() + if frame: + await self.push_frame(frame) + + async def stop_all_metrics(self): + await self.stop_ttfb_metrics() + await self.stop_processing_metrics() + + async def cleanup(self): + pass + + def link(self, processor: 'FrameProcessor'): + self._next = processor + processor._prev = self + logger.debug(f"Linking {self} -> {self._next}") + + def get_event_loop(self) -> asyncio.AbstractEventLoop: + return self._loop + + async def process_frame(self, frame: Frame, direction: FrameDirection): + if isinstance(frame, StartFrame): + self._allow_interruptions = frame.allow_interruptions + self._enable_metrics = frame.enable_metrics + self._report_only_initial_ttfb = frame.report_only_initial_ttfb + elif isinstance(frame, StartInterruptionFrame): + await self.stop_all_metrics() + elif isinstance(frame, UserStoppedSpeakingFrame): + self._should_report_ttfb = True + + async def push_error(self, error: ErrorFrame): + await self.push_frame(error, FrameDirection.UPSTREAM) + + async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM): + try: + if direction == FrameDirection.DOWNSTREAM and self._next: + logger.trace(f"Pushing {frame} from {self} to {self._next}") + await self._next.process_frame(frame, direction) + elif direction == FrameDirection.UPSTREAM and self._prev: + logger.trace(f"Pushing {frame} upstream from {self} to {self._prev}") + await self._prev.process_frame(frame, direction) + except Exception as e: + logger.exception(f"Uncaught exception in {self}: {e}") + + def __str__(self): + return self.name diff --git a/pipecat/processors/frameworks/__init__.py b/pipecat/processors/frameworks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pipecat/processors/frameworks/langchain.py b/pipecat/processors/frameworks/langchain.py new file mode 100644 index 0000000000000000000000000000000000000000..628d3300fe21507dd3f95c0720ad3f6a89556151 --- /dev/null +++ b/pipecat/processors/frameworks/langchain.py @@ -0,0 +1,80 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from typing import Union + +from pipecat.frames.frames import ( + Frame, + LLMFullResponseEndFrame, + LLMFullResponseStartFrame, + LLMMessagesFrame, + LLMResponseEndFrame, + LLMResponseStartFrame, + TextFrame) +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor + +from loguru import logger + +try: + from langchain_core.messages import AIMessageChunk + from langchain_core.runnables import Runnable +except ModuleNotFoundError as e: + logger.exception( + "In order to use Langchain, you need to `pip install pipecat-ai[langchain]`. 
" + ) + raise Exception(f"Missing module: {e}") + + +class LangchainProcessor(FrameProcessor): + def __init__(self, chain: Runnable, transcript_key: str = "input"): + super().__init__() + self._chain = chain + self._transcript_key = transcript_key + self._participant_id: str | None = None + + def set_participant_id(self, participant_id: str): + self._participant_id = participant_id + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if isinstance(frame, LLMMessagesFrame): + # Messages are accumulated by the `LLMUserResponseAggregator` in a list of messages. + # The last one by the human is the one we want to send to the LLM. + logger.debug(f"Got transcription frame {frame}") + text: str = frame.messages[-1]["content"] + + await self._ainvoke(text.strip()) + else: + await self.push_frame(frame, direction) + + @staticmethod + def __get_token_value(text: Union[str, AIMessageChunk]) -> str: + match text: + case str(): + return text + case AIMessageChunk(): + return text.content + case _: + return "" + + async def _ainvoke(self, text: str): + logger.debug(f"Invoking chain with {text}") + await self.push_frame(LLMFullResponseStartFrame()) + try: + async for token in self._chain.astream( + {self._transcript_key: text}, + config={"configurable": {"session_id": self._participant_id}}, + ): + await self.push_frame(LLMResponseStartFrame()) + await self.push_frame(TextFrame(self.__get_token_value(token))) + await self.push_frame(LLMResponseEndFrame()) + except GeneratorExit: + logger.warning(f"{self} generator was closed prematurely") + except Exception as e: + logger.exception(f"{self} an unknown error occurred: {e}") + finally: + await self.push_frame(LLMFullResponseEndFrame()) diff --git a/pipecat/processors/logger.py b/pipecat/processors/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..0417bf0ce31760e62ed3d91b2f7ba59a46c1e344 --- /dev/null +++ b/pipecat/processors/logger.py @@ -0,0 +1,27 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from pipecat.frames.frames import Frame +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from loguru import logger +from typing import Optional +logger = logger.opt(ansi=True) + + +class FrameLogger(FrameProcessor): + def __init__(self, prefix="Frame", color: Optional[str] = None): + super().__init__() + self._prefix = prefix + self._color = color + + async def process_frame(self, frame: Frame, direction: FrameDirection): + dir = "<" if direction is FrameDirection.UPSTREAM else ">" + msg = f"{dir} {self._prefix}: {frame}" + if self._color: + msg = f"<{self._color}>{msg}" + logger.debug(msg) + + await self.push_frame(frame, direction) diff --git a/pipecat/processors/text_transformer.py b/pipecat/processors/text_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..65a0e2d113baa2752c66ba4f09a153fae774567e --- /dev/null +++ b/pipecat/processors/text_transformer.py @@ -0,0 +1,38 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from typing import Coroutine + +from pipecat.frames.frames import Frame, TextFrame +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor + + +class StatelessTextTransformer(FrameProcessor): + """This processor calls the given function on any text in a text frame. + + >>> async def print_frames(aggregator, frame): + ... 
async for frame in aggregator.process_frame(frame): + ... print(frame.text) + + >>> aggregator = StatelessTextTransformer(lambda x: x.upper()) + >>> asyncio.run(print_frames(aggregator, TextFrame("Hello"))) + HELLO + """ + + def __init__(self, transform_fn): + super().__init__() + self._transform_fn = transform_fn + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if isinstance(frame, TextFrame): + result = self._transform_fn(frame.text) + if isinstance(result, Coroutine): + result = await result + await self.push_frame(result) + else: + await self.push_frame(frame, direction) diff --git a/pipecat/serializers/__init__.py b/pipecat/serializers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pipecat/serializers/base_serializer.py b/pipecat/serializers/base_serializer.py new file mode 100644 index 0000000000000000000000000000000000000000..e690d1c503bbc7298dfa3bd019a90630c8f270c3 --- /dev/null +++ b/pipecat/serializers/base_serializer.py @@ -0,0 +1,20 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from abc import ABC, abstractmethod + +from pipecat.frames.frames import Frame + + +class FrameSerializer(ABC): + + @abstractmethod + def serialize(self, frame: Frame) -> str | bytes | None: + pass + + @abstractmethod + def deserialize(self, data: str | bytes) -> Frame | None: + pass diff --git a/pipecat/serializers/protobuf.py b/pipecat/serializers/protobuf.py new file mode 100644 index 0000000000000000000000000000000000000000..60ce36cf54f52efede09190580a62c2f4f500bf4 --- /dev/null +++ b/pipecat/serializers/protobuf.py @@ -0,0 +1,92 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import dataclasses + +import pipecat.frames.protobufs.frames_pb2 as frame_protos + +from pipecat.frames.frames import AudioRawFrame, Frame, TextFrame, TranscriptionFrame +from pipecat.serializers.base_serializer import FrameSerializer + +from loguru import logger + + +class ProtobufFrameSerializer(FrameSerializer): + SERIALIZABLE_TYPES = { + TextFrame: "text", + AudioRawFrame: "audio", + TranscriptionFrame: "transcription" + } + + SERIALIZABLE_FIELDS = {v: k for k, v in SERIALIZABLE_TYPES.items()} + + def __init__(self): + pass + + def serialize(self, frame: Frame) -> str | bytes | None: + proto_frame = frame_protos.Frame() + if type(frame) not in self.SERIALIZABLE_TYPES: + raise ValueError( + f"Frame type {type(frame)} is not serializable. You may need to add it to ProtobufFrameSerializer.SERIALIZABLE_FIELDS.") + + # ignoring linter errors; we check that type(frame) is in this dict above + proto_optional_name = self.SERIALIZABLE_TYPES[type(frame)] # type: ignore + for field in dataclasses.fields(frame): # type: ignore + setattr(getattr(proto_frame, proto_optional_name), field.name, + getattr(frame, field.name)) + + result = proto_frame.SerializeToString() + return result + + def deserialize(self, data: str | bytes) -> Frame | None: + """Returns a Frame object from a Frame protobuf. Used to convert frames + passed over the wire as protobufs to Frame objects used in pipelines + and frame processors. + + >>> serializer = ProtobufFrameSerializer() + >>> serializer.deserialize( + ... serializer.serialize(AudioFrame(data=b'1234567890'))) + AudioFrame(data=b'1234567890') + + >>> serializer.deserialize( + ... 
serializer.serialize(TextFrame(text='hello world'))) + TextFrame(text='hello world') + + >>> serializer.deserialize(serializer.serialize(TranscriptionFrame( + ... text="Hello there!", participantId="123", timestamp="2021-01-01"))) + TranscriptionFrame(text='Hello there!', participantId='123', timestamp='2021-01-01') + """ + + proto = frame_protos.Frame.FromString(data) + which = proto.WhichOneof("frame") + if which not in self.SERIALIZABLE_FIELDS: + logger.error("Unable to deserialize a valid frame") + return None + + class_name = self.SERIALIZABLE_FIELDS[which] + args = getattr(proto, which) + args_dict = {} + for field in proto.DESCRIPTOR.fields_by_name[which].message_type.fields: + args_dict[field.name] = getattr(args, field.name) + + # Remove special fields if needed + id = getattr(args, "id") + name = getattr(args, "name") + if not id: + del args_dict["id"] + if not name: + del args_dict["name"] + + # Create the instance + instance = class_name(**args_dict) + + # Set special fields + if id: + setattr(instance, "id", getattr(args, "id")) + if name: + setattr(instance, "name", getattr(args, "name")) + + return instance diff --git a/pipecat/serializers/twilio.py b/pipecat/serializers/twilio.py new file mode 100644 index 0000000000000000000000000000000000000000..bd6792d39e7d019e71bc170b1adecdfe6e4a0d24 --- /dev/null +++ b/pipecat/serializers/twilio.py @@ -0,0 +1,52 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import base64 +import json + +from pipecat.frames.frames import AudioRawFrame, Frame +from pipecat.serializers.base_serializer import FrameSerializer +from pipecat.utils.audio import ulaw_8000_to_pcm_16000, pcm_16000_to_ulaw_8000 + + +class TwilioFrameSerializer(FrameSerializer): + SERIALIZABLE_TYPES = { + AudioRawFrame: "audio", + } + + def __init__(self, stream_sid: str): + self._stream_sid = stream_sid + + def serialize(self, frame: Frame) -> str | bytes | None: + if not isinstance(frame, AudioRawFrame): + return None + + data = frame.audio + + serialized_data = pcm_16000_to_ulaw_8000(data) + payload = base64.b64encode(serialized_data).decode("utf-8") + answer = { + "event": "media", + "streamSid": self._stream_sid, + "media": { + "payload": payload + } + } + + return json.dumps(answer) + + def deserialize(self, data: str | bytes) -> Frame | None: + message = json.loads(data) + + if message["event"] != "media": + return None + else: + payload_base64 = message["media"]["payload"] + payload = base64.b64decode(payload_base64) + + deserialized_data = ulaw_8000_to_pcm_16000(payload) + audio_frame = AudioRawFrame(audio=deserialized_data, num_channels=1, sample_rate=16000) + return audio_frame diff --git a/pipecat/services/__init__.py b/pipecat/services/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pipecat/services/ai_services.py b/pipecat/services/ai_services.py new file mode 100644 index 0000000000000000000000000000000000000000..0977c6e9efe5470a959329aa6a5e6bac53695335 --- /dev/null +++ b/pipecat/services/ai_services.py @@ -0,0 +1,300 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import io +import wave + +from abc import abstractmethod +from typing import AsyncGenerator + +from pipecat.frames.frames import ( + AudioRawFrame, + CancelFrame, + EndFrame, + ErrorFrame, + Frame, + LLMFullResponseEndFrame, + StartFrame, + StartInterruptionFrame, + TTSStartedFrame, + TTSStoppedFrame, + TextFrame, + 
VisionImageRawFrame, +) +from pipecat.processors.async_frame_processor import AsyncFrameProcessor +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.utils.audio import calculate_audio_volume +from pipecat.utils.utils import exp_smoothing + + +class AIService(FrameProcessor): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def start(self, frame: StartFrame): + pass + + async def stop(self, frame: EndFrame): + pass + + async def cancel(self, frame: CancelFrame): + pass + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if isinstance(frame, StartFrame): + await self.start(frame) + elif isinstance(frame, CancelFrame): + await self.cancel(frame) + elif isinstance(frame, EndFrame): + await self.stop(frame) + + async def process_generator(self, generator: AsyncGenerator[Frame, None]): + async for f in generator: + if isinstance(f, ErrorFrame): + await self.push_error(f) + else: + await self.push_frame(f) + + +class AsyncAIService(AsyncFrameProcessor): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def start(self, frame: StartFrame): + pass + + async def stop(self, frame: EndFrame): + pass + + async def cancel(self, frame: CancelFrame): + pass + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if isinstance(frame, StartFrame): + await self.start(frame) + elif isinstance(frame, CancelFrame): + await self.cancel(frame) + elif isinstance(frame, EndFrame): + await self.stop(frame) + + +class LLMService(AIService): + """This class is a no-op but serves as a base class for LLM services.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._callbacks = {} + self._start_callbacks = {} + + # TODO-CB: callback function type + def register_function(self, function_name: str, callback, start_callback=None): + self._callbacks[function_name] = callback + if start_callback: + self._start_callbacks[function_name] = start_callback + + def unregister_function(self, function_name: str): + del self._callbacks[function_name] + if self._start_callbacks[function_name]: + del self._start_callbacks[function_name] + + def has_function(self, function_name: str): + return function_name in self._callbacks.keys() + + async def call_function(self, function_name: str, args): + if function_name in self._callbacks.keys(): + return await self._callbacks[function_name](self, args) + return None + + async def call_start_function(self, function_name: str): + if function_name in self._start_callbacks.keys(): + await self._start_callbacks[function_name](self) + + +class TTSService(AIService): + def __init__(self, *, aggregate_sentences: bool = True, **kwargs): + super().__init__(**kwargs) + self._aggregate_sentences: bool = aggregate_sentences + self._current_sentence: str = "" + + # Converts the text to audio. 
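+    # Subclasses implement this as an async generator that yields AudioRawFrame
+    # chunks (and ErrorFrame on failure). A minimal sketch, assuming a
+    # hypothetical `synthesize()` helper that returns 16-bit mono PCM:
+    #
+    #     async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
+    #         audio = await synthesize(text)  # hypothetical helper
+    #         yield AudioRawFrame(audio=audio, sample_rate=16000, num_channels=1)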
+ @abstractmethod + async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: + pass + + async def say(self, text: str): + await self.process_frame(TextFrame(text=text), FrameDirection.DOWNSTREAM) + + async def _process_text_frame(self, frame: TextFrame): + text: str | None = None + if not self._aggregate_sentences: + text = frame.text + else: + self._current_sentence += frame.text + if self._current_sentence.strip().endswith( + (".", "?", "!")) and not self._current_sentence.strip().endswith( + ("Mr,", "Mrs.", "Ms.", "Dr.")): + text = self._current_sentence + self._current_sentence = "" + + if text: + await self._push_tts_frames(text) + + async def _push_tts_frames(self, text: str): + text = text.strip() + if not text: + return + + await self.push_frame(TTSStartedFrame()) + await self.start_processing_metrics() + await self.process_generator(self.run_tts(text)) + await self.stop_processing_metrics() + await self.push_frame(TTSStoppedFrame()) + # We send the original text after the audio. This way, if we are + # interrupted, the text is not added to the assistant context. + await self.push_frame(TextFrame(text)) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if isinstance(frame, TextFrame): + await self._process_text_frame(frame) + elif isinstance(frame, StartInterruptionFrame): + self._current_sentence = "" + await self.push_frame(frame, direction) + elif isinstance(frame, LLMFullResponseEndFrame) or isinstance(frame, EndFrame): + self._current_sentence = "" + await self._push_tts_frames(self._current_sentence) + await self.push_frame(frame) + else: + await self.push_frame(frame, direction) + + +class STTService(AIService): + """STTService is a base class for speech-to-text services.""" + + def __init__(self, + *, + min_volume: float = 0.6, + max_silence_secs: float = 0.3, + max_buffer_secs: float = 1.5, + sample_rate: int = 16000, + num_channels: int = 1, + **kwargs): + super().__init__(**kwargs) + self._min_volume = min_volume + self._max_silence_secs = max_silence_secs + self._max_buffer_secs = max_buffer_secs + self._sample_rate = sample_rate + self._num_channels = num_channels + (self._content, self._wave) = self._new_wave() + self._silence_num_frames = 0 + # Volume exponential smoothing + self._smoothing_factor = 0.2 + self._prev_volume = 0 + + @abstractmethod + async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: + """Returns transcript as a string""" + pass + + def _new_wave(self): + content = io.BytesIO() + ww = wave.open(content, "wb") + ww.setsampwidth(2) + ww.setnchannels(self._num_channels) + ww.setframerate(self._sample_rate) + return (content, ww) + + def _get_smoothed_volume(self, frame: AudioRawFrame) -> float: + volume = calculate_audio_volume(frame.audio, frame.sample_rate) + return exp_smoothing(volume, self._prev_volume, self._smoothing_factor) + + async def _append_audio(self, frame: AudioRawFrame): + # Try to filter out empty background noise + volume = self._get_smoothed_volume(frame) + if volume >= self._min_volume: + # If volume is high enough, write new data to wave file + self._wave.writeframes(frame.audio) + self._silence_num_frames = 0 + else: + self._silence_num_frames += frame.num_frames + self._prev_volume = volume + + # If buffer is not empty and we have enough data or there's been a long + # silence, transcribe the audio gathered so far. 
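+        # With the default 16000 Hz sample rate, max_buffer_secs=1.5 corresponds
+        # to 24000 buffered frames and max_silence_secs=0.3 to 4800 silent frames.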
+ silence_secs = self._silence_num_frames / self._sample_rate + buffer_secs = self._wave.getnframes() / self._sample_rate + if self._content.tell() > 0 and ( + buffer_secs > self._max_buffer_secs or silence_secs > self._max_silence_secs): + self._silence_num_frames = 0 + self._wave.close() + self._content.seek(0) + await self.start_processing_metrics() + await self.process_generator(self.run_stt(self._content.read())) + await self.stop_processing_metrics() + (self._content, self._wave) = self._new_wave() + + async def process_frame(self, frame: Frame, direction: FrameDirection): + """Processes a frame of audio data, either buffering or transcribing it.""" + await super().process_frame(frame, direction) + + if isinstance(frame, CancelFrame) or isinstance(frame, EndFrame): + self._wave.close() + await self.push_frame(frame, direction) + elif isinstance(frame, AudioRawFrame): + # In this service we accumulate audio internally and at the end we + # push a TextFrame. We don't really want to push audio frames down. + await self._append_audio(frame) + else: + await self.push_frame(frame, direction) + + +class ImageGenService(AIService): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # Renders the image. Returns an Image object. + @abstractmethod + async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]: + pass + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if isinstance(frame, TextFrame): + await self.push_frame(frame, direction) + await self.start_processing_metrics() + await self.process_generator(self.run_image_gen(frame.text)) + await self.stop_processing_metrics() + else: + await self.push_frame(frame, direction) + + +class VisionService(AIService): + """VisionService is a base class for vision services.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._describe_text = None + + @abstractmethod + async def run_vision(self, frame: VisionImageRawFrame) -> AsyncGenerator[Frame, None]: + pass + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if isinstance(frame, VisionImageRawFrame): + await self.start_processing_metrics() + await self.process_generator(self.run_vision(frame)) + await self.stop_processing_metrics() + else: + await self.push_frame(frame, direction) diff --git a/pipecat/services/anthropic.py b/pipecat/services/anthropic.py new file mode 100644 index 0000000000000000000000000000000000000000..50d01965676028e11483c84f68d8264389587d91 --- /dev/null +++ b/pipecat/services/anthropic.py @@ -0,0 +1,145 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import base64 + +from pipecat.frames.frames import ( + Frame, + TextFrame, + VisionImageRawFrame, + LLMMessagesFrame, + LLMFullResponseStartFrame, + LLMResponseStartFrame, + LLMResponseEndFrame, + LLMFullResponseEndFrame +) +from pipecat.processors.frame_processor import FrameDirection +from pipecat.services.ai_services import LLMService +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext, OpenAILLMContextFrame + +from loguru import logger + +try: + from anthropic import AsyncAnthropic +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error( + "In order to use Anthropic, you need to `pip install pipecat-ai[anthropic]`. 
Also, set `ANTHROPIC_API_KEY` environment variable.") + raise Exception(f"Missing module: {e}") + + +class AnthropicLLMService(LLMService): + """This class implements inference with Anthropic's AI models + + This service translates internally from OpenAILLMContext to the messages format + expected by the Anthropic Python SDK. We are using the OpenAILLMContext as a lingua + franca for all LLM services, so that it is easy to switch between different LLMs. + """ + + def __init__( + self, + *, + api_key: str, + model: str = "claude-3-opus-20240229", + max_tokens: int = 1024): + super().__init__() + self._client = AsyncAnthropic(api_key=api_key) + self._model = model + self._max_tokens = max_tokens + + def can_generate_metrics(self) -> bool: + return True + + def _get_messages_from_openai_context( + self, context: OpenAILLMContext): + openai_messages = context.get_messages() + anthropic_messages = [] + + for message in openai_messages: + role = message["role"] + text = message["content"] + if role == "system": + role = "user" + if message.get("mime_type") == "image/jpeg": + # vision frame + encoded_image = base64.b64encode(message["data"].getvalue()).decode("utf-8") + anthropic_messages.append({ + "role": role, + "content": [{ + "type": "image", + "source": { + "type": "base64", + "media_type": message.get("mime_type"), + "data": encoded_image, + } + }, { + "type": "text", + "text": text + }] + }) + else: + # Text frame. Anthropic needs the roles to alternate. This will + # cause an issue with interruptions. So, if we detect we are the + # ones asking again it probably means we were interrupted. + if role == "user" and len(anthropic_messages) > 1: + last_message = anthropic_messages[-1] + if last_message["role"] == "user": + anthropic_messages = anthropic_messages[:-1] + content = last_message["content"] + anthropic_messages.append( + {"role": "user", "content": f"Sorry, I just asked you about [{content}] but now I would like to know [{text}]."}) + else: + anthropic_messages.append({"role": role, "content": text}) + else: + anthropic_messages.append({"role": role, "content": text}) + + return anthropic_messages + + async def _process_context(self, context: OpenAILLMContext): + await self.push_frame(LLMFullResponseStartFrame()) + try: + logger.debug(f"Generating chat: {context.get_messages_json()}") + + messages = self._get_messages_from_openai_context(context) + + await self.start_ttfb_metrics() + + response = await self._client.messages.create( + messages=messages, + model=self._model, + max_tokens=self._max_tokens, + stream=True) + + await self.stop_ttfb_metrics() + + async for event in response: + # logger.debug(f"Anthropic LLM event: {event}") + if (event.type == "content_block_delta"): + await self.push_frame(LLMResponseStartFrame()) + await self.push_frame(TextFrame(event.delta.text)) + await self.push_frame(LLMResponseEndFrame()) + + except Exception as e: + logger.exception(f"{self} exception: {e}") + finally: + await self.push_frame(LLMFullResponseEndFrame()) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + context = None + + if isinstance(frame, OpenAILLMContextFrame): + context: OpenAILLMContext = frame.context + elif isinstance(frame, LLMMessagesFrame): + context = OpenAILLMContext.from_messages(frame.messages) + elif isinstance(frame, VisionImageRawFrame): + context = OpenAILLMContext.from_image_frame(frame) + else: + await self.push_frame(frame, direction) + + if context: + await 
self._process_context(context) diff --git a/pipecat/services/azure.py b/pipecat/services/azure.py new file mode 100644 index 0000000000000000000000000000000000000000..10b8a290b1354f52651379fab830222826ded604 --- /dev/null +++ b/pipecat/services/azure.py @@ -0,0 +1,233 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import aiohttp +import asyncio +import io +import time + +from PIL import Image +from typing import AsyncGenerator + +from pipecat.frames.frames import ( + AudioRawFrame, + CancelFrame, + EndFrame, + ErrorFrame, + Frame, + StartFrame, + StartInterruptionFrame, + SystemFrame, + TranscriptionFrame, + URLImageRawFrame) +from pipecat.processors.frame_processor import FrameDirection +from pipecat.services.ai_services import AIService, AsyncAIService, TTSService, ImageGenService +from pipecat.services.openai import BaseOpenAILLMService + +from loguru import logger + +# See .env.example for Azure configuration needed +try: + from openai import AsyncAzureOpenAI + from azure.cognitiveservices.speech import ( + SpeechConfig, + SpeechRecognizer, + SpeechSynthesizer, + ResultReason, + CancellationReason, + ) + from azure.cognitiveservices.speech.audio import AudioStreamFormat, PushAudioInputStream + from azure.cognitiveservices.speech.dialog import AudioConfig +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error( + "In order to use Azure, you need to `pip install pipecat-ai[azure]`. Also, set `AZURE_SPEECH_API_KEY` and `AZURE_SPEECH_REGION` environment variables.") + raise Exception(f"Missing module: {e}") + + +class AzureLLMService(BaseOpenAILLMService): + def __init__( + self, + *, + api_key: str, + endpoint: str, + model: str, + api_version: str = "2023-12-01-preview"): + # Initialize variables before calling parent __init__() because that + # will call create_client() and we need those values there. + self._endpoint = endpoint + self._api_version = api_version + super().__init__(api_key=api_key, model=model) + + def create_client(self, api_key=None, base_url=None, **kwargs): + return AsyncAzureOpenAI( + api_key=api_key, + azure_endpoint=self._endpoint, + api_version=self._api_version, + ) + + +class AzureTTSService(TTSService): + def __init__(self, *, api_key: str, region: str, voice="en-US-SaraNeural", **kwargs): + super().__init__(**kwargs) + + speech_config = SpeechConfig(subscription=api_key, region=region) + self._speech_synthesizer = SpeechSynthesizer(speech_config=speech_config, audio_config=None) + + self._voice = voice + + def can_generate_metrics(self) -> bool: + return True + + async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: + logger.debug(f"Generating TTS: {text}") + + await self.start_ttfb_metrics() + + ssml = ( + "" + f"" + "" + "" + "" + f"{text}" + " ") + + result = await asyncio.to_thread(self._speech_synthesizer.speak_ssml, (ssml)) + + if result.reason == ResultReason.SynthesizingAudioCompleted: + await self.stop_ttfb_metrics() + # Azure always sends a 44-byte header. Strip it off. 
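+            # (44 bytes matches a standard RIFF/WAV header, so what remains
+            # should be raw 16-bit PCM ready to wrap in an AudioRawFrame.)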
+ yield AudioRawFrame(audio=result.audio_data[44:], sample_rate=16000, num_channels=1) + elif result.reason == ResultReason.Canceled: + cancellation_details = result.cancellation_details + logger.warning(f"Speech synthesis canceled: {cancellation_details.reason}") + if cancellation_details.reason == CancellationReason.Error: + logger.error(f"{self} error: {cancellation_details.error_details}") + + +class AzureSTTService(AsyncAIService): + def __init__( + self, + *, + api_key: str, + region: str, + language="en-US", + sample_rate=16000, + channels=1, + **kwargs): + super().__init__(**kwargs) + + speech_config = SpeechConfig(subscription=api_key, region=region) + speech_config.speech_recognition_language = language + + stream_format = AudioStreamFormat(samples_per_second=sample_rate, channels=channels) + self._audio_stream = PushAudioInputStream(stream_format) + + audio_config = AudioConfig(stream=self._audio_stream) + self._speech_recognizer = SpeechRecognizer( + speech_config=speech_config, audio_config=audio_config) + self._speech_recognizer.recognized.connect(self._on_handle_recognized) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if isinstance(frame, SystemFrame): + await self.push_frame(frame, direction) + elif isinstance(frame, AudioRawFrame): + self._audio_stream.write(frame.audio) + else: + await self._push_queue.put((frame, direction)) + + async def start(self, frame: StartFrame): + self._speech_recognizer.start_continuous_recognition_async() + + async def stop(self, frame: EndFrame): + self._speech_recognizer.stop_continuous_recognition_async() + + async def cancel(self, frame: CancelFrame): + self._speech_recognizer.stop_continuous_recognition_async() + + def _on_handle_recognized(self, event): + if event.result.reason == ResultReason.RecognizedSpeech and len(event.result.text) > 0: + frame = TranscriptionFrame(event.result.text, "", int(time.time_ns() / 1000000)) + asyncio.run_coroutine_threadsafe(self.queue_frame(frame), self.get_event_loop()) + + +class AzureImageGenServiceREST(ImageGenService): + + def __init__( + self, + *, + aiohttp_session: aiohttp.ClientSession, + image_size: str, + api_key: str, + endpoint: str, + model: str, + api_version="2023-06-01-preview", + ): + super().__init__() + + self._api_key = api_key + self._azure_endpoint = endpoint + self._api_version = api_version + self._model = model + self._aiohttp_session = aiohttp_session + self._image_size = image_size + + async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]: + url = f"{self._azure_endpoint}openai/images/generations:submit?api-version={self._api_version}" + + headers = { + "api-key": self._api_key, + "Content-Type": "application/json"} + + body = { + # Enter your prompt text here + "prompt": prompt, + "size": self._image_size, + "n": 1, + } + + async with self._aiohttp_session.post(url, headers=headers, json=body) as submission: + # We never get past this line, because this header isn't + # defined on a 429 response, but something is eating our + # exceptions! 
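+            # A more defensive variant (sketch only, not applied here) would
+            # check the response before touching the header, e.g.:
+            #
+            #     if "operation-location" not in submission.headers:
+            #         yield ErrorFrame(f"Image generation failed ({submission.status})")
+            #         return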
+ operation_location = submission.headers["operation-location"] + status = "" + attempts_left = 120 + json_response = None + while status != "succeeded": + attempts_left -= 1 + if attempts_left == 0: + logger.error(f"{self} error: image generation timed out") + yield ErrorFrame("Image generation timed out") + return + + await asyncio.sleep(1) + + response = await self._aiohttp_session.get(operation_location, headers=headers) + + json_response = await response.json() + status = json_response["status"] + + image_url = json_response["result"]["data"][0]["url"] if json_response else None + if not image_url: + logger.error(f"{self} error: image generation failed") + yield ErrorFrame("Image generation failed") + return + + # Load the image from the url + async with self._aiohttp_session.get(image_url) as response: + image_stream = io.BytesIO(await response.content.read()) + image = Image.open(image_stream) + frame = URLImageRawFrame( + url=image_url, + image=image.tobytes(), + size=image.size, + format=image.format) + yield frame diff --git a/pipecat/services/cartesia.py b/pipecat/services/cartesia.py new file mode 100644 index 0000000000000000000000000000000000000000..ebed28e76c53307743483d3d81607daff892c0b6 --- /dev/null +++ b/pipecat/services/cartesia.py @@ -0,0 +1,65 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from cartesia import AsyncCartesia + +from typing import AsyncGenerator + +from pipecat.frames.frames import AudioRawFrame, Frame +from pipecat.services.ai_services import TTSService + +from loguru import logger + + +class CartesiaTTSService(TTSService): + + def __init__( + self, + *, + api_key: str, + voice_id: str, + model_id: str = "sonic-english", + encoding: str = "pcm_s16le", + sample_rate: int = 16000, + **kwargs): + super().__init__(**kwargs) + + self._api_key = api_key + self._model_id = model_id + self._output_format = { + "container": "raw", + "encoding": encoding, + "sample_rate": sample_rate, + } + + try: + self._client = AsyncCartesia(api_key=self._api_key) + self._voice = self._client.voices.get(id=voice_id) + except Exception as e: + logger.exception(f"{self} initialization error: {e}") + + def can_generate_metrics(self) -> bool: + return True + + async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: + logger.debug(f"Generating TTS: [{text}]") + + try: + await self.start_ttfb_metrics() + + chunk_generator = await self._client.tts.sse( + stream=True, + transcript=text, + voice_embedding=self._voice["embedding"], + model_id=self._model_id, + output_format=self._output_format, + ) + + async for chunk in chunk_generator: + await self.stop_ttfb_metrics() + yield AudioRawFrame(chunk["audio"], self._output_format["sample_rate"], 1) + except Exception as e: + logger.exception(f"{self} exception: {e}") diff --git a/pipecat/services/deepgram.py b/pipecat/services/deepgram.py new file mode 100644 index 0000000000000000000000000000000000000000..4b8bc817eb0bc8c6aa8c27bb7f1689a856c613bf --- /dev/null +++ b/pipecat/services/deepgram.py @@ -0,0 +1,149 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import aiohttp +import time + +from typing import AsyncGenerator + +from pipecat.frames.frames import ( + AudioRawFrame, + CancelFrame, + EndFrame, + ErrorFrame, + Frame, + InterimTranscriptionFrame, + StartFrame, + SystemFrame, + TranscriptionFrame) +from pipecat.processors.frame_processor import FrameDirection +from pipecat.services.ai_services import AsyncAIService, TTSService + 
+from loguru import logger + +# See .env.example for Deepgram configuration needed +try: + from deepgram import ( + DeepgramClient, + DeepgramClientOptions, + LiveTranscriptionEvents, + LiveOptions, + ) +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error( + "In order to use Deepgram, you need to `pip install pipecat-ai[deepgram]`. Also, set `DEEPGRAM_API_KEY` environment variable.") + raise Exception(f"Missing module: {e}") + + +class DeepgramTTSService(TTSService): + + def __init__( + self, + *, + aiohttp_session: aiohttp.ClientSession, + api_key: str, + voice: str = "aura-helios-en", + base_url: str = "https://api.deepgram.com/v1/speak", + **kwargs): + super().__init__(**kwargs) + + self._voice = voice + self._api_key = api_key + self._aiohttp_session = aiohttp_session + self._base_url = base_url + + def can_generate_metrics(self) -> bool: + return True + + async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: + logger.debug(f"Generating TTS: [{text}]") + + base_url = self._base_url + request_url = f"{base_url}?model={self._voice}&encoding=linear16&container=none&sample_rate=16000" + headers = {"authorization": f"token {self._api_key}"} + body = {"text": text} + + try: + await self.start_ttfb_metrics() + async with self._aiohttp_session.post(request_url, headers=headers, json=body) as r: + if r.status != 200: + response_text = await r.text() + # If we get a a "Bad Request: Input is unutterable", just print out a debug log. + # All other unsuccesful requests should emit an error frame. If not specifically + # handled by the running PipelineTask, the ErrorFrame will cancel the task. + if "unutterable" in response_text: + logger.debug(f"Unutterable text: [{text}]") + return + + logger.error( + f"{self} error getting audio (status: {r.status}, error: {response_text})") + yield ErrorFrame(f"Error getting audio (status: {r.status}, error: {response_text})") + return + + async for data in r.content: + await self.stop_ttfb_metrics() + frame = AudioRawFrame(audio=data, sample_rate=16000, num_channels=1) + yield frame + except Exception as e: + logger.exception(f"{self} exception: {e}") + + +class DeepgramSTTService(AsyncAIService): + def __init__(self, + *, + api_key: str, + url: str = "", + live_options: LiveOptions = LiveOptions( + encoding="linear16", + language="en-US", + model="nova-2-conversationalai", + sample_rate=16000, + channels=1, + interim_results=True, + smart_format=True, + ), + **kwargs): + super().__init__(**kwargs) + + self._live_options = live_options + + self._client = DeepgramClient( + api_key, config=DeepgramClientOptions(url=url, options={"keepalive": "true"})) + self._connection = self._client.listen.asynclive.v("1") + self._connection.on(LiveTranscriptionEvents.Transcript, self._on_message) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if isinstance(frame, SystemFrame): + await self.push_frame(frame, direction) + elif isinstance(frame, AudioRawFrame): + await self._connection.send(frame.audio) + else: + await self.queue_frame(frame, direction) + + async def start(self, frame: StartFrame): + if await self._connection.start(self._live_options): + logger.debug(f"{self}: Connected to Deepgram") + else: + logger.error(f"{self}: Unable to connect to Deepgram") + + async def stop(self, frame: EndFrame): + await self._connection.finish() + + async def cancel(self, frame: CancelFrame): + await self._connection.finish() + + async def _on_message(self, *args, 
**kwargs): + result = kwargs["result"] + is_final = result.is_final + transcript = result.channel.alternatives[0].transcript + if len(transcript) > 0: + if is_final: + await self.queue_frame(TranscriptionFrame(transcript, "", int(time.time_ns() / 1000000))) + else: + await self.queue_frame(InterimTranscriptionFrame(transcript, "", int(time.time_ns() / 1000000))) diff --git a/pipecat/services/elevenlabs.py b/pipecat/services/elevenlabs.py new file mode 100644 index 0000000000000000000000000000000000000000..bbcdd089c703377e099546d654785ed32a7de862 --- /dev/null +++ b/pipecat/services/elevenlabs.py @@ -0,0 +1,66 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import aiohttp + +from typing import AsyncGenerator + +from pipecat.frames.frames import AudioRawFrame, ErrorFrame, Frame +from pipecat.services.ai_services import TTSService + +from loguru import logger + + +class ElevenLabsTTSService(TTSService): + + def __init__( + self, + *, + aiohttp_session: aiohttp.ClientSession, + api_key: str, + voice_id: str, + model: str = "eleven_turbo_v2", + **kwargs): + super().__init__(**kwargs) + + self._api_key = api_key + self._voice_id = voice_id + self._aiohttp_session = aiohttp_session + self._model = model + + def can_generate_metrics(self) -> bool: + return True + + async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: + logger.debug(f"Generating TTS: [{text}]") + + url = f"https://api.elevenlabs.io/v1/text-to-speech/{self._voice_id}/stream" + + payload = {"text": text, "model_id": self._model} + + querystring = { + "output_format": "pcm_16000", + "optimize_streaming_latency": 2} + + headers = { + "xi-api-key": self._api_key, + "Content-Type": "application/json", + } + + await self.start_ttfb_metrics() + + async with self._aiohttp_session.post(url, json=payload, headers=headers, params=querystring) as r: + if r.status != 200: + text = await r.text() + logger.error(f"{self} error getting audio (status: {r.status}, error: {text})") + yield ErrorFrame(f"Error getting audio (status: {r.status}, error: {text})") + return + + async for chunk in r.content: + if len(chunk) > 0: + await self.stop_ttfb_metrics() + frame = AudioRawFrame(chunk, 16000, 1) + yield frame diff --git a/pipecat/services/fal.py b/pipecat/services/fal.py new file mode 100644 index 0000000000000000000000000000000000000000..e58826f46f21537272a8b0c7926d518564ac122c --- /dev/null +++ b/pipecat/services/fal.py @@ -0,0 +1,83 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import aiohttp +import io +import os + +from PIL import Image +from pydantic import BaseModel +from typing import AsyncGenerator, Optional, Union, Dict + +from pipecat.frames.frames import ErrorFrame, Frame, URLImageRawFrame +from pipecat.services.ai_services import ImageGenService + +from loguru import logger + +try: + import fal_client +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error( + "In order to use Fal, you need to `pip install pipecat-ai[fal]`. 
Also, set `FAL_KEY` environment variable.") + raise Exception(f"Missing module: {e}") + + +class FalImageGenService(ImageGenService): + class InputParams(BaseModel): + seed: Optional[int] = None + num_inference_steps: int = 8 + num_images: int = 1 + image_size: Union[str, Dict[str, int]] = "square_hd" + expand_prompt: bool = False + enable_safety_checker: bool = True + format: str = "png" + + def __init__( + self, + *, + aiohttp_session: aiohttp.ClientSession, + params: InputParams, + model: str = "fal-ai/fast-sdxl", + key: str | None = None, + ): + super().__init__() + self._model = model + self._params = params + self._aiohttp_session = aiohttp_session + if key: + os.environ["FAL_KEY"] = key + + async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]: + logger.debug(f"Generating image from prompt: {prompt}") + + response = await fal_client.run_async( + self._model, + arguments={"prompt": prompt, **self._params.model_dump()} + ) + + image_url = response["images"][0]["url"] if response else None + + if not image_url: + logger.error(f"{self} error: image generation failed") + yield ErrorFrame("Image generation failed") + return + + logger.debug(f"Image generated at: {image_url}") + + # Load the image from the url + logger.debug(f"Downloading image {image_url} ...") + async with self._aiohttp_session.get(image_url) as response: + logger.debug(f"Downloaded image {image_url}") + image_stream = io.BytesIO(await response.content.read()) + image = Image.open(image_stream) + + frame = URLImageRawFrame( + url=image_url, + image=image.tobytes(), + size=image.size, + format=image.format) + yield frame diff --git a/pipecat/services/fireworks.py b/pipecat/services/fireworks.py new file mode 100644 index 0000000000000000000000000000000000000000..d418ad40302f91bb1864056a7fac86f8437207e2 --- /dev/null +++ b/pipecat/services/fireworks.py @@ -0,0 +1,25 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from pipecat.services.openai import BaseOpenAILLMService + +from loguru import logger + +try: + from openai import AsyncOpenAI +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error( + "In order to use Fireworks, you need to `pip install pipecat-ai[fireworks]`. 
Also, set the `FIREWORKS_API_KEY` environment variable.") + raise Exception(f"Missing module: {e}") + + +class FireworksLLMService(BaseOpenAILLMService): + def __init__(self, + *, + model: str = "accounts/fireworks/models/firefunction-v1", + base_url: str = "https://api.fireworks.ai/inference/v1"): + super().__init__(model, base_url) diff --git a/pipecat/services/google.py b/pipecat/services/google.py new file mode 100644 index 0000000000000000000000000000000000000000..8ad98ed081cef24037cc8496f8fc5e728a7b354a --- /dev/null +++ b/pipecat/services/google.py @@ -0,0 +1,129 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio + +from typing import List + +from pipecat.frames.frames import ( + Frame, + TextFrame, + VisionImageRawFrame, + LLMMessagesFrame, + LLMFullResponseStartFrame, + LLMResponseStartFrame, + LLMResponseEndFrame, + LLMFullResponseEndFrame +) +from pipecat.processors.frame_processor import FrameDirection +from pipecat.services.ai_services import LLMService +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext, OpenAILLMContextFrame + +from loguru import logger + +try: + import google.generativeai as gai + import google.ai.generativelanguage as glm +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error( + "In order to use Google AI, you need to `pip install pipecat-ai[google]`. Also, set `GOOGLE_API_KEY` environment variable.") + raise Exception(f"Missing module: {e}") + + +class GoogleLLMService(LLMService): + """This class implements inference with Google's AI models + + This service translates internally from OpenAILLMContext to the messages format + expected by the Google AI model. We are using the OpenAILLMContext as a lingua + franca for all LLM services, so that it is easy to switch between different LLMs. 
+ """ + + def __init__(self, *, api_key: str, model: str = "gemini-1.5-flash-latest", **kwargs): + super().__init__(**kwargs) + gai.configure(api_key=api_key) + self._client = gai.GenerativeModel(model) + + def can_generate_metrics(self) -> bool: + return True + + def _get_messages_from_openai_context( + self, context: OpenAILLMContext) -> List[glm.Content]: + openai_messages = context.get_messages() + google_messages = [] + + for message in openai_messages: + role = message["role"] + content = message["content"] + if role == "system": + role = "user" + elif role == "assistant": + role = "model" + + parts = [glm.Part(text=content)] + if "mime_type" in message: + parts.append( + glm.Part(inline_data=glm.Blob( + mime_type=message["mime_type"], + data=message["data"].getvalue() + ))) + google_messages.append({"role": role, "parts": parts}) + + return google_messages + + async def _async_generator_wrapper(self, sync_generator): + for item in sync_generator: + yield item + await asyncio.sleep(0) + + async def _process_context(self, context: OpenAILLMContext): + await self.push_frame(LLMFullResponseStartFrame()) + try: + logger.debug(f"Generating chat: {context.get_messages_json()}") + + messages = self._get_messages_from_openai_context(context) + + await self.start_ttfb_metrics() + + response = self._client.generate_content(messages, stream=True) + + await self.stop_ttfb_metrics() + + async for chunk in self._async_generator_wrapper(response): + try: + text = chunk.text + await self.push_frame(LLMResponseStartFrame()) + await self.push_frame(TextFrame(text)) + await self.push_frame(LLMResponseEndFrame()) + except Exception as e: + # Google LLMs seem to flag safety issues a lot! + if chunk.candidates[0].finish_reason == 3: + logger.debug( + f"LLM refused to generate content for safety reasons - {messages}.") + else: + logger.exception(f"{self} error: {e}") + + except Exception as e: + logger.exception(f"{self} exception: {e}") + finally: + await self.push_frame(LLMFullResponseEndFrame()) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + context = None + + if isinstance(frame, OpenAILLMContextFrame): + context: OpenAILLMContext = frame.context + elif isinstance(frame, LLMMessagesFrame): + context = OpenAILLMContext.from_messages(frame.messages) + elif isinstance(frame, VisionImageRawFrame): + context = OpenAILLMContext.from_image_frame(frame) + else: + await self.push_frame(frame, direction) + + if context: + await self._process_context(context) diff --git a/pipecat/services/moondream.py b/pipecat/services/moondream.py new file mode 100644 index 0000000000000000000000000000000000000000..eddc05374bec3c441134b85a30f529238fc5e6bb --- /dev/null +++ b/pipecat/services/moondream.py @@ -0,0 +1,92 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio + +from PIL import Image + +from typing import AsyncGenerator + +from pipecat.frames.frames import ErrorFrame, Frame, TextFrame, VisionImageRawFrame +from pipecat.services.ai_services import VisionService + +from loguru import logger + +try: + import torch + + from transformers import AutoModelForCausalLM, AutoTokenizer +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error("In order to use Moondream, you need to `pip install pipecat-ai[moondream]`.") + raise Exception(f"Missing module(s): {e}") + + +def detect_device(): + """ + Detects the appropriate device to run on, and return the device and 
dtype. + """ + try: + import intel_extension_for_pytorch + if torch.xpu.is_available(): + return torch.device("xpu"), torch.float32 + except ImportError: + pass + if torch.cuda.is_available(): + return torch.device("cuda"), torch.float16 + elif torch.backends.mps.is_available(): + return torch.device("mps"), torch.float16 + else: + return torch.device("cpu"), torch.float32 + + +class MoondreamService(VisionService): + def __init__( + self, + *, + model="vikhyatk/moondream2", + revision="2024-04-02", + use_cpu=False + ): + super().__init__() + + if not use_cpu: + device, dtype = detect_device() + else: + device = torch.device("cpu") + dtype = torch.float32 + + self._tokenizer = AutoTokenizer.from_pretrained(model, revision=revision) + + logger.debug("Loading Moondream model...") + + self._model = AutoModelForCausalLM.from_pretrained( + model, trust_remote_code=True, revision=revision + ).to(device=device, dtype=dtype) + self._model.eval() + + logger.debug("Loaded Moondream model") + + async def run_vision(self, frame: VisionImageRawFrame) -> AsyncGenerator[Frame, None]: + if not self._model: + logger.error(f"{self} error: Moondream model not available") + yield ErrorFrame("Moondream model not available") + return + + logger.debug(f"Analyzing image: {frame}") + + def get_image_description(frame: VisionImageRawFrame): + image = Image.frombytes(frame.format, frame.size, frame.image) + image_embeds = self._model.encode_image(image) + description = self._model.answer_question( + image_embeds=image_embeds, + question=frame.text, + tokenizer=self._tokenizer) + return description + + description = await asyncio.to_thread(get_image_description, frame) + + yield TextFrame(text=description) diff --git a/pipecat/services/ollama.py b/pipecat/services/ollama.py new file mode 100644 index 0000000000000000000000000000000000000000..b12cf1980e56f60787c4f0a4e4eece587d52578b --- /dev/null +++ b/pipecat/services/ollama.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from pipecat.services.openai import BaseOpenAILLMService + + +class OLLamaLLMService(BaseOpenAILLMService): + + def __init__(self, *, model: str = "llama2", base_url: str = "http://localhost:11434/v1"): + super().__init__(model=model, base_url=base_url, api_key="ollama") diff --git a/pipecat/services/openai.py b/pipecat/services/openai.py new file mode 100644 index 0000000000000000000000000000000000000000..cb53261afa1bf03f10c8c3184ebf420e058f2600 --- /dev/null +++ b/pipecat/services/openai.py @@ -0,0 +1,338 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import aiohttp +import base64 +import io +import json + +from typing import AsyncGenerator, List, Literal + +from loguru import logger +from PIL import Image + +from pipecat.frames.frames import ( + AudioRawFrame, + ErrorFrame, + Frame, + LLMFullResponseEndFrame, + LLMFullResponseStartFrame, + LLMMessagesFrame, + LLMResponseEndFrame, + LLMResponseStartFrame, + TextFrame, + URLImageRawFrame, + VisionImageRawFrame +) +from pipecat.processors.aggregators.openai_llm_context import ( + OpenAILLMContext, + OpenAILLMContextFrame +) +from pipecat.processors.frame_processor import FrameDirection +from pipecat.services.ai_services import ( + ImageGenService, + LLMService, + TTSService +) + +try: + from openai import AsyncOpenAI, AsyncStream, BadRequestError + from openai.types.chat import ( + ChatCompletionChunk, + ChatCompletionFunctionMessageParam, + ChatCompletionMessageParam, + 
ChatCompletionToolParam + ) +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error( + "In order to use OpenAI, you need to `pip install pipecat-ai[openai]`. Also, set `OPENAI_API_KEY` environment variable.") + raise Exception(f"Missing module: {e}") + + +class OpenAIUnhandledFunctionException(Exception): + pass + + +class BaseOpenAILLMService(LLMService): + """This is the base for all services that use the AsyncOpenAI client. + + This service consumes OpenAILLMContextFrame frames, which contain a reference + to an OpenAILLMContext frame. The OpenAILLMContext object defines the context + sent to the LLM for a completion. This includes user, assistant and system messages + as well as tool choices and the tool, which is used if requesting function + calls from the LLM. + """ + + def __init__(self, *, model: str, api_key=None, base_url=None, **kwargs): + super().__init__(**kwargs) + self._model: str = model + self._client = self.create_client(api_key=api_key, base_url=base_url, **kwargs) + + def create_client(self, api_key=None, base_url=None, **kwargs): + return AsyncOpenAI(api_key=api_key, base_url=base_url) + + def can_generate_metrics(self) -> bool: + return True + + async def get_chat_completions( + self, + context: OpenAILLMContext, + messages: List[ChatCompletionMessageParam]) -> AsyncStream[ChatCompletionChunk]: + chunks = await self._client.chat.completions.create( + model=self._model, + stream=True, + messages=messages, + tools=context.tools, + tool_choice=context.tool_choice, + ) + return chunks + + async def _stream_chat_completions( + self, context: OpenAILLMContext) -> AsyncStream[ChatCompletionChunk]: + logger.debug(f"Generating chat: {context.get_messages_json()}") + + messages: List[ChatCompletionMessageParam] = context.get_messages() + + # base64 encode any images + for message in messages: + if message.get("mime_type") == "image/jpeg": + encoded_image = base64.b64encode(message["data"].getvalue()).decode("utf-8") + text = message["content"] + message["content"] = [ + {"type": "text", "text": text}, + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}} + ] + del message["data"] + del message["mime_type"] + + chunks = await self.get_chat_completions(context, messages) + + return chunks + + async def _process_context(self, context: OpenAILLMContext): + function_name = "" + arguments = "" + tool_call_id = "" + + await self.start_ttfb_metrics() + + chunk_stream: AsyncStream[ChatCompletionChunk] = ( + await self._stream_chat_completions(context) + ) + + async for chunk in chunk_stream: + if len(chunk.choices) == 0: + continue + + await self.stop_ttfb_metrics() + + if chunk.choices[0].delta.tool_calls: + # We're streaming the LLM response to enable the fastest response times. + # For text, we just yield each chunk as we receive it and count on consumers + # to do whatever coalescing they need (eg. to pass full sentences to TTS) + # + # If the LLM is a function call, we'll do some coalescing here. + # If the response contains a function name, we'll yield a frame to tell consumers + # that they can start preparing to call the function with that name. + # We accumulate all the arguments for the rest of the streamed response, then when + # the response is done, we package up all the arguments and the function name and + # yield a frame containing the function name and the arguments. 
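+                # For example, a single tool call usually arrives as one delta that
+                # carries the tool call id and function name, followed by deltas that
+                # each carry a fragment of the JSON arguments string; we concatenate
+                # the fragments below and parse them once the stream is finished.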
+ + tool_call = chunk.choices[0].delta.tool_calls[0] + if tool_call.function and tool_call.function.name: + function_name += tool_call.function.name + tool_call_id = tool_call.id + await self.call_start_function(function_name) + if tool_call.function and tool_call.function.arguments: + # Keep iterating through the response to collect all the argument fragments + arguments += tool_call.function.arguments + elif chunk.choices[0].delta.content: + await self.push_frame(LLMResponseStartFrame()) + await self.push_frame(TextFrame(chunk.choices[0].delta.content)) + await self.push_frame(LLMResponseEndFrame()) + + # if we got a function name and arguments, check to see if it's a function with + # a registered handler. If so, run the registered callback, save the result to + # the context, and re-prompt to get a chat answer. If we don't have a registered + # handler, raise an exception. + if function_name and arguments: + if self.has_function(function_name): + await self._handle_function_call(context, tool_call_id, function_name, arguments) + else: + raise OpenAIUnhandledFunctionException( + f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function.") + + async def _handle_function_call( + self, + context, + tool_call_id, + function_name, + arguments + ): + arguments = json.loads(arguments) + result = await self.call_function(function_name, arguments) + arguments = json.dumps(arguments) + if isinstance(result, (str, dict)): + # Handle it in "full magic mode" + tool_call = ChatCompletionFunctionMessageParam({ + "role": "assistant", + "tool_calls": [ + { + "id": tool_call_id, + "function": { + "arguments": arguments, + "name": function_name + }, + "type": "function" + } + ] + + }) + context.add_message(tool_call) + if isinstance(result, dict): + result = json.dumps(result) + tool_result = ChatCompletionToolParam({ + "tool_call_id": tool_call_id, + "role": "tool", + "content": result + }) + context.add_message(tool_result) + # re-prompt to get a human answer + await self._process_context(context) + elif isinstance(result, list): + # reduced magic + for msg in result: + context.add_message(msg) + await self._process_context(context) + elif isinstance(result, type(None)): + pass + else: + raise TypeError(f"Unknown return type from function callback: {type(result)}") + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + context = None + if isinstance(frame, OpenAILLMContextFrame): + context: OpenAILLMContext = frame.context + elif isinstance(frame, LLMMessagesFrame): + context = OpenAILLMContext.from_messages(frame.messages) + elif isinstance(frame, VisionImageRawFrame): + context = OpenAILLMContext.from_image_frame(frame) + else: + await self.push_frame(frame, direction) + + if context: + await self.push_frame(LLMFullResponseStartFrame()) + await self.start_processing_metrics() + await self._process_context(context) + await self.stop_processing_metrics() + await self.push_frame(LLMFullResponseEndFrame()) + + +class OpenAILLMService(BaseOpenAILLMService): + + def __init__(self, *, model: str = "gpt-4o", **kwargs): + super().__init__(model=model, **kwargs) + + +class OpenAIImageGenService(ImageGenService): + + def __init__( + self, + *, + image_size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"], + aiohttp_session: aiohttp.ClientSession, + api_key: str, + model: str = "dall-e-3", + ): + super().__init__() + self._model = model + self._image_size = 
image_size
+        self._client = AsyncOpenAI(api_key=api_key)
+        self._aiohttp_session = aiohttp_session
+
+    async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
+        logger.debug(f"Generating image from prompt: {prompt}")
+
+        image = await self._client.images.generate(
+            prompt=prompt,
+            model=self._model,
+            n=1,
+            size=self._image_size
+        )
+
+        image_url = image.data[0].url
+
+        if not image_url:
+            logger.error(f"{self} No image provided in response: {image}")
+            yield ErrorFrame("Image generation failed")
+            return
+
+        # Load the image from the url
+        async with self._aiohttp_session.get(image_url) as response:
+            image_stream = io.BytesIO(await response.content.read())
+            image = Image.open(image_stream)
+            frame = URLImageRawFrame(url=image_url, image=image.tobytes(), size=image.size, format=image.format)
+            yield frame
+
+
+class OpenAITTSService(TTSService):
+    """This service uses the OpenAI TTS API to generate audio from text.
+    The returned audio is PCM encoded at 24kHz. When using the DailyTransport, set the sample rate in the DailyParams accordingly:
+    ```
+    DailyParams(
+        audio_out_enabled=True,
+        audio_out_sample_rate=24_000,
+    )
+    ```
+    """
+
+    def __init__(
+            self,
+            *,
+            api_key: str | None = None,
+            base_url: str | None = None,
+            sample_rate: int = 24_000,
+            voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = "alloy",
+            model: Literal["tts-1", "tts-1-hd"] = "tts-1",
+            **kwargs):
+        super().__init__(**kwargs)
+
+        self._voice = voice
+        self._model = model
+        self.sample_rate = sample_rate
+        self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+
+    def can_generate_metrics(self) -> bool:
+        return True
+
+    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
+        logger.debug(f"Generating TTS: [{text}]")
+
+        try:
+            await self.start_ttfb_metrics()
+
+            async with self._client.audio.speech.with_streaming_response.create(
+                    input=text,
+                    model=self._model,
+                    voice=self._voice,
+                    response_format="pcm",
+            ) as r:
+                if r.status_code != 200:
+                    error = await r.text()
+                    logger.error(
+                        f"{self} error getting audio (status: {r.status_code}, error: {error})")
+                    yield ErrorFrame(f"Error getting audio (status: {r.status_code}, error: {error})")
+                    return
+                async for chunk in r.iter_bytes(8192):
+                    if len(chunk) > 0:
+                        await self.stop_ttfb_metrics()
+                        frame = AudioRawFrame(chunk, self.sample_rate, 1)
+                        yield frame
+        except BadRequestError as e:
+            logger.exception(f"{self} error generating TTS: {e}")
diff --git a/pipecat/services/openpipe.py b/pipecat/services/openpipe.py
new file mode 100644
index 0000000000000000000000000000000000000000..77d9b112add58c4c920726fe87df32c6dc20d4cd
--- /dev/null
+++ b/pipecat/services/openpipe.py
@@ -0,0 +1,71 @@
+#
+# Copyright (c) 2024, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+from typing import Dict, List
+
+from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
+from pipecat.services.openai import BaseOpenAILLMService
+
+from loguru import logger
+
+try:
+    from openpipe import AsyncOpenAI as OpenPipeAI, AsyncStream
+    from openai.types.chat import (ChatCompletionMessageParam, ChatCompletionChunk)
+except ModuleNotFoundError as e:
+    logger.error(f"Exception: {e}")
+    logger.error(
+        "In order to use OpenPipe, you need to `pip install pipecat-ai[openpipe]`. 
Also, set `OPENPIPE_API_KEY` and `OPENAI_API_KEY` environment variables.") + raise Exception(f"Missing module: {e}") + + +class OpenPipeLLMService(BaseOpenAILLMService): + + def __init__( + self, + *, + model: str = "gpt-4o", + api_key: str | None = None, + base_url: str | None = None, + openpipe_api_key: str | None = None, + openpipe_base_url: str = "https://app.openpipe.ai/api/v1", + tags: Dict[str, str] | None = None, + **kwargs): + super().__init__( + model=model, + api_key=api_key, + base_url=base_url, + openpipe_api_key=openpipe_api_key, + openpipe_base_url=openpipe_base_url, + **kwargs) + self._tags = tags + + def create_client(self, api_key=None, base_url=None, **kwargs): + openpipe_api_key = kwargs.get("openpipe_api_key") or "" + openpipe_base_url = kwargs.get("openpipe_base_url") or "" + client = OpenPipeAI( + api_key=api_key, + base_url=base_url, + openpipe={ + "api_key": openpipe_api_key, + "base_url": openpipe_base_url + } + ) + return client + + async def get_chat_completions( + self, + context: OpenAILLMContext, + messages: List[ChatCompletionMessageParam]) -> AsyncStream[ChatCompletionChunk]: + chunks = await self._client.chat.completions.create( + model=self._model, + stream=True, + messages=messages, + openpipe={ + "tags": self._tags, + "log_request": True + } + ) + return chunks diff --git a/pipecat/services/playht.py b/pipecat/services/playht.py new file mode 100644 index 0000000000000000000000000000000000000000..e61daeb27db25d8c301264ba7078b2351f6e496a --- /dev/null +++ b/pipecat/services/playht.py @@ -0,0 +1,83 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import io +import struct + +from typing import AsyncGenerator + +from pipecat.frames.frames import AudioRawFrame, Frame +from pipecat.services.ai_services import TTSService + +from loguru import logger + +try: + from pyht.client import TTSOptions + from pyht.async_client import AsyncClient + from pyht.protos.api_pb2 import Format +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error( + "In order to use PlayHT, you need to `pip install pipecat-ai[playht]`. Also, set `PLAY_HT_USER_ID` and `PLAY_HT_API_KEY` environment variables.") + raise Exception(f"Missing module: {e}") + + +class PlayHTTTSService(TTSService): + + def __init__(self, *, api_key: str, user_id: str, voice_url: str, **kwargs): + super().__init__(**kwargs) + + self._user_id = user_id + self._speech_key = api_key + + self._client = AsyncClient( + user_id=self._user_id, + api_key=self._speech_key, + ) + self._options = TTSOptions( + voice=voice_url, + sample_rate=16000, + quality="higher", + format=Format.FORMAT_WAV) + + def can_generate_metrics(self) -> bool: + return True + + async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: + logger.debug(f"Generating TTS: [{text}]") + + try: + b = bytearray() + in_header = True + + await self.start_ttfb_metrics() + + playht_gen = self._client.tts( + text, + voice_engine="PlayHT2.0-turbo", + options=self._options) + + async for chunk in playht_gen: + # skip the RIFF header. 
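+                # The stream is a WAV file: the 12-byte RIFF chunk and a standard PCM
+                # fmt chunk end at byte offset 36, after which further chunk headers
+                # follow. We buffer bytes, walk the chunk headers until we reach the
+                # 'data' chunk, and only yield the raw PCM that comes after it.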
+                if in_header:
+                    b.extend(chunk)
+                    if len(b) <= 36:
+                        continue
+                    else:
+                        fh = io.BytesIO(b)
+                        fh.seek(36)
+                        (data, size) = struct.unpack('<4sI', fh.read(8))
+                        while data != b'data':
+                            fh.read(size)
+                            (data, size) = struct.unpack('<4sI', fh.read(8))
+                        in_header = False
+                else:
+                    if len(chunk):
+                        await self.stop_ttfb_metrics()
+                        frame = AudioRawFrame(chunk, 16000, 1)
+                        yield frame
+        except Exception as e:
+            logger.exception(f"{self} error generating TTS: {e}")
diff --git a/pipecat/services/to_be_updated/__init__.py b/pipecat/services/to_be_updated/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/pipecat/services/to_be_updated/cloudflare_ai_service.py b/pipecat/services/to_be_updated/cloudflare_ai_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca22919d8850299029a4e808ff0e238b8bca4262
--- /dev/null
+++ b/pipecat/services/to_be_updated/cloudflare_ai_service.py
@@ -0,0 +1,71 @@
+import requests
+import os
+from services.ai_service import AIService
+
+# Note that Cloudflare's AI workers are still in beta.
+# https://developers.cloudflare.com/workers-ai/
+
+
+class CloudflareAIService(AIService):
+    def __init__(self):
+        super().__init__()
+        self.cloudflare_account_id = os.getenv("CLOUDFLARE_ACCOUNT_ID")
+        self.cloudflare_api_token = os.getenv("CLOUDFLARE_API_TOKEN")
+
+        self.api_base_url = f'https://api.cloudflare.com/client/v4/accounts/{self.cloudflare_account_id}/ai/run/'
+        self.headers = {"Authorization": f'Bearer {self.cloudflare_api_token}'}
+
+    # base endpoint, used by the others
+    def run(self, model, input):
+        response = requests.post(
+            f"{self.api_base_url}{model}",
+            headers=self.headers,
+            json=input)
+        return response.json()
+
+    # https://developers.cloudflare.com/workers-ai/models/llm/
+    def run_llm(self, messages, latest_user_message=None, stream=True):
+        input = {
+            "messages": [
+                {"role": "system", "content": "You are a friendly assistant"},
+                {"role": "user", "content": latest_user_message}
+            ]
+        }
+
+        return self.run("@cf/meta/llama-2-7b-chat-int8", input)
+
+    # https://developers.cloudflare.com/workers-ai/models/translation/
+    def run_text_translation(self, sentence, source_language, target_language):
+        return self.run('@cf/meta/m2m100-1.2b', {
+            "text": sentence,
+            "source_lang": source_language,
+            "target_lang": target_language
+        })
+
+    # https://developers.cloudflare.com/workers-ai/models/sentiment-analysis/
+    def run_text_sentiment(self, sentence):
+        return self.run("@cf/huggingface/distilbert-sst-2-int8",
+                        {"text": sentence})
+
+    # https://developers.cloudflare.com/workers-ai/models/image-classification/
+    def run_image_classification(self, image_url):
+        response = requests.get(image_url)
+
+        if response.status_code != 200:
+            return {"error": "There was a problem downloading the image."}
+
+        if response.status_code == 200:
+            data = response.content
+            inputs = {"image": list(data)}
+
+        return self.run("@cf/microsoft/resnet-50", inputs)
+
+    # https://developers.cloudflare.com/workers-ai/models/embedding/
+    def run_embeddings(self, texts, size="medium"):
+        models = {
+            "small": "@cf/baai/bge-small-en-v1.5",  # 384 output dimensions
+            "medium": "@cf/baai/bge-base-en-v1.5",  # 768 output dimensions
+            "large": "@cf/baai/bge-large-en-v1.5"  # 1024 output dimensions
+        }
+
+        return self.run(models[size], {"text": texts})
diff --git a/pipecat/services/to_be_updated/google_ai_service.py b/pipecat/services/to_be_updated/google_ai_service.py
new file mode 100644
index 
0000000000000000000000000000000000000000..b0889adf309e6cec92bff63cc9bf34753703c15d --- /dev/null +++ b/pipecat/services/to_be_updated/google_ai_service.py @@ -0,0 +1,31 @@ +from services.ai_service import AIService +import openai +import os + +# To use Google Cloud's AI products, you'll need to install Google Cloud +# CLI and enable the TTS and in your project: +# https://cloud.google.com/sdk/docs/install +from google.cloud import texttospeech + + +class GoogleAIService(AIService): + def __init__(self): + super().__init__() + + self.client = texttospeech.TextToSpeechClient() + self.voice = texttospeech.VoiceSelectionParams( + language_code="en-GB", name="en-GB-Neural2-F" + ) + + self.audio_config = texttospeech.AudioConfig( + audio_encoding=texttospeech.AudioEncoding.LINEAR16, + sample_rate_hertz=16000 + ) + + def run_tts(self, sentence): + synthesis_input = texttospeech.SynthesisInput(text=sentence.strip()) + result = self.client.synthesize_speech( + input=synthesis_input, + voice=self.voice, + audio_config=self.audio_config) + return result diff --git a/pipecat/services/to_be_updated/huggingface_ai_service.py b/pipecat/services/to_be_updated/huggingface_ai_service.py new file mode 100644 index 0000000000000000000000000000000000000000..6e6132cd9300feb4974e9ad495cf73e070daed9c --- /dev/null +++ b/pipecat/services/to_be_updated/huggingface_ai_service.py @@ -0,0 +1,33 @@ +from services.ai_service import AIService +from transformers import pipeline + +# These functions are just intended for testing, not production use. If +# you'd like to use HuggingFace, you should use your own models, or do +# some research into the specific models that will work best for your use +# case. + + +class HuggingFaceAIService(AIService): + def __init__(self): + super().__init__() + + def run_text_sentiment(self, sentence): + classifier = pipeline("sentiment-analysis") + return classifier(sentence) + + # available models at https://huggingface.co/Helsinki-NLP (**not all + # models use 2-character language codes**) + def run_text_translation(self, sentence, source_language, target_language): + translator = pipeline( + f"translation", + model=f"Helsinki-NLP/opus-mt-{source_language}-{target_language}") + + return translator(sentence)[0]["translation_text"] + + def run_text_summarization(self, sentence): + summarizer = pipeline("summarization") + return summarizer(sentence) + + def run_image_classification(self, image_path): + classifier = pipeline("image-classification") + return classifier(image_path) diff --git a/pipecat/services/to_be_updated/mock_ai_service.py b/pipecat/services/to_be_updated/mock_ai_service.py new file mode 100644 index 0000000000000000000000000000000000000000..597bf20242dcd4654c7984bb1c88abda18409428 --- /dev/null +++ b/pipecat/services/to_be_updated/mock_ai_service.py @@ -0,0 +1,27 @@ +import io +import requests +import time +from PIL import Image +from services.ai_service import AIService + + +class MockAIService(AIService): + def __init__(self): + super().__init__() + + def run_tts(self, sentence): + print("running tts", sentence) + time.sleep(2) + + def run_image_gen(self, sentence): + image_url = "https://d3d00swyhr67nd.cloudfront.net/w800h800/collection/ASH/ASHM/ASH_ASHM_WA1940_2_22-001.jpg" + response = requests.get(image_url) + image_stream = io.BytesIO(response.content) + image = Image.open(image_stream) + time.sleep(1) + return (image_url, image.tobytes(), image.size) + + def run_llm(self, messages, latest_user_message=None, stream=True): + for i in range(5): + time.sleep(1) + yield 
({"choices": [{"delta": {"content": f"hello {i}!"}}]}) diff --git a/pipecat/services/whisper.py b/pipecat/services/whisper.py new file mode 100644 index 0000000000000000000000000000000000000000..3a7ba242d24bd141ef13f8b2b30dd9c0e2bb2c4a --- /dev/null +++ b/pipecat/services/whisper.py @@ -0,0 +1,94 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""This module implements Whisper transcription with a locally-downloaded model.""" + +import asyncio +import time + +from enum import Enum +from typing_extensions import AsyncGenerator + +import numpy as np + +from pipecat.frames.frames import ErrorFrame, Frame, TranscriptionFrame +from pipecat.services.ai_services import STTService + +from loguru import logger + +try: + from faster_whisper import WhisperModel +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error( + "In order to use Whisper, you need to `pip install pipecat-ai[whisper]`.") + raise Exception(f"Missing module: {e}") + + +class Model(Enum): + """Class of basic Whisper model selection options""" + TINY = "tiny" + BASE = "base" + MEDIUM = "medium" + LARGE = "large-v3" + DISTIL_LARGE_V2 = "Systran/faster-distil-whisper-large-v2" + DISTIL_MEDIUM_EN = "Systran/faster-distil-whisper-medium.en" + + +class WhisperSTTService(STTService): + """Class to transcribe audio with a locally-downloaded Whisper model""" + + def __init__(self, + *, + model: str | Model = Model.DISTIL_MEDIUM_EN, + device: str = "auto", + compute_type: str = "default", + no_speech_prob: float = 0.4, + **kwargs): + + super().__init__(**kwargs) + self._device: str = device + self._compute_type = compute_type + self._model_name: str | Model = model + self._no_speech_prob = no_speech_prob + self._model: WhisperModel | None = None + self._load() + + def can_generate_metrics(self) -> bool: + return True + + def _load(self): + """Loads the Whisper model. Note that if this is the first time + this model is being run, it will take time to download.""" + logger.debug("Loading Whisper model...") + self._model = WhisperModel( + self._model_name.value if isinstance(self._model_name, Enum) else self._model_name, + device=self._device, + compute_type=self._compute_type) + logger.debug("Loaded Whisper model") + + async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: + """Transcribes given audio using Whisper""" + if not self._model: + logger.error(f"{self} error: Whisper model not available") + yield ErrorFrame("Whisper model not available") + return + + await self.start_ttfb_metrics() + + # Divide by 32768 because we have signed 16-bit data. 
+ audio_float = np.frombuffer(audio, dtype=np.int16).astype(np.float32) / 32768.0 + + segments, _ = await asyncio.to_thread(self._model.transcribe, audio_float) + text: str = "" + for segment in segments: + if segment.no_speech_prob < self._no_speech_prob: + text += f"{segment.text} " + + if text: + await self.stop_ttfb_metrics() + logger.debug(f"Transcription: [{text}]") + yield TranscriptionFrame(text, "", int(time.time_ns() / 1000000)) diff --git a/pipecat/services/xtts.py b/pipecat/services/xtts.py new file mode 100644 index 0000000000000000000000000000000000000000..887d6c2baea550b7718d217cc0f79c9881e71234 --- /dev/null +++ b/pipecat/services/xtts.py @@ -0,0 +1,112 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import aiohttp + +from typing import AsyncGenerator + +from pipecat.frames.frames import AudioRawFrame, ErrorFrame, Frame +from pipecat.services.ai_services import TTSService + +from loguru import logger + +import requests + +import numpy as np + +try: + import resampy +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error("In order to use XTTS, you need to `pip install pipecat-ai[xtts]`.") + raise Exception(f"Missing module: {e}") + + +# The server below can connect to XTTS through a local running docker +# +# Docker command: $ docker run --gpus=all -e COQUI_TOS_AGREED=1 --rm -p 8000:80 ghcr.io/coqui-ai/xtts-streaming-server:latest-cuda121 +# +# You can find more information on the official repo: +# https://github.com/coqui-ai/xtts-streaming-server + + +class XTTSService(TTSService): + + def __init__( + self, + *, + aiohttp_session: aiohttp.ClientSession, + voice_id: str, + language: str, + base_url: str, + **kwargs): + super().__init__(**kwargs) + + self._voice_id = voice_id + self._language = language + self._base_url = base_url + self._aiohttp_session = aiohttp_session + self._studio_speakers = requests.get(self._base_url + "/studio_speakers").json() + + def can_generate_metrics(self) -> bool: + return True + + async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: + logger.debug(f"Generating TTS: [{text}]") + embeddings = self._studio_speakers[self._voice_id] + + url = self._base_url + "/tts_stream" + + payload = { + "text": text.replace('.', '').replace('*', ''), + "language": self._language, + "speaker_embedding": embeddings["speaker_embedding"], + "gpt_cond_latent": embeddings["gpt_cond_latent"], + "add_wav_header": False, + "stream_chunk_size": 20, + } + + await self.start_ttfb_metrics() + + async with self._aiohttp_session.post(url, json=payload) as r: + if r.status != 200: + text = await r.text() + logger.error(f"{self} error getting audio (status: {r.status}, error: {text})") + yield ErrorFrame(f"Error getting audio (status: {r.status}, error: {text})") + return + + buffer = bytearray() + + async for chunk in r.content.iter_chunked(1024): + if len(chunk) > 0: + await self.stop_ttfb_metrics() + # Append new chunk to the buffer + buffer.extend(chunk) + + # Check if buffer has enough data for processing + while len(buffer) >= 48000: # Assuming at least 0.5 seconds of audio data at 24000 Hz + # Process the buffer up to a safe size for resampling + process_data = buffer[:48000] + # Remove processed data from buffer + buffer = buffer[48000:] + + # Convert the byte data to numpy array for resampling + audio_np = np.frombuffer(process_data, dtype=np.int16) + # Resample the audio from 24000 Hz to 16000 Hz + resampled_audio = resampy.resample(audio_np, 24000, 16000) + # Convert the numpy 
array back to bytes + resampled_audio_bytes = resampled_audio.astype(np.int16).tobytes() + # Create the frame with the resampled audio + frame = AudioRawFrame(resampled_audio_bytes, 16000, 1) + yield frame + + # Process any remaining data in the buffer + if len(buffer) > 0: + audio_np = np.frombuffer(buffer, dtype=np.int16) + resampled_audio = resampy.resample(audio_np, 24000, 16000) + resampled_audio_bytes = resampled_audio.astype(np.int16).tobytes() + frame = AudioRawFrame(resampled_audio_bytes, 16000, 1) + yield frame diff --git a/pipecat/transports/__init__.py b/pipecat/transports/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pipecat/transports/base_input.py b/pipecat/transports/base_input.py new file mode 100644 index 0000000000000000000000000000000000000000..6a6a2beb23286e96b937941e468b5b99a2042358 --- /dev/null +++ b/pipecat/transports/base_input.py @@ -0,0 +1,177 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio + +from concurrent.futures import ThreadPoolExecutor + +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.frames.frames import ( + AudioRawFrame, + CancelFrame, + StartFrame, + EndFrame, + Frame, + StartInterruptionFrame, + StopInterruptionFrame, + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame) +from pipecat.transports.base_transport import TransportParams +from pipecat.vad.vad_analyzer import VADAnalyzer, VADState + +from loguru import logger + + +class BaseInputTransport(FrameProcessor): + + def __init__(self, params: TransportParams, **kwargs): + super().__init__(**kwargs) + + self._params = params + + self._executor = ThreadPoolExecutor(max_workers=5) + + # Create push frame task. This is the task that will push frames in + # order. We also guarantee that all frames are pushed in the same task. + self._create_push_task() + + async def start(self, frame: StartFrame): + # Create audio input queue and task if needed. + if self._params.audio_in_enabled or self._params.vad_enabled: + self._audio_in_queue = asyncio.Queue() + self._audio_task = self.get_event_loop().create_task(self._audio_task_handler()) + + async def stop(self): + # Wait for the task to finish. + if self._params.audio_in_enabled or self._params.vad_enabled: + self._audio_task.cancel() + await self._audio_task + + def vad_analyzer(self) -> VADAnalyzer | None: + return self._params.vad_analyzer + + async def push_audio_frame(self, frame: AudioRawFrame): + if self._params.audio_in_enabled or self._params.vad_enabled: + await self._audio_in_queue.put(frame) + + # + # Frame processor + # + + async def cleanup(self): + self._push_frame_task.cancel() + await self._push_frame_task + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if isinstance(frame, CancelFrame): + await self.stop() + # We don't queue a CancelFrame since we want to stop ASAP. 
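+            # Other frames below go through _internal_push_frame() so their order is
+            # preserved; a CancelFrame bypasses that queue and is pushed directly.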
+ await self.push_frame(frame, direction) + elif isinstance(frame, StartFrame): + await self.start(frame) + await self._internal_push_frame(frame, direction) + elif isinstance(frame, EndFrame): + await self._internal_push_frame(frame, direction) + await self.stop() + else: + await self._internal_push_frame(frame, direction) + + # + # Push frames task + # + + def _create_push_task(self): + loop = self.get_event_loop() + self._push_queue = asyncio.Queue() + self._push_frame_task = loop.create_task(self._push_frame_task_handler()) + + async def _internal_push_frame( + self, + frame: Frame | None, + direction: FrameDirection | None = FrameDirection.DOWNSTREAM): + await self._push_queue.put((frame, direction)) + + async def _push_frame_task_handler(self): + while True: + try: + (frame, direction) = await self._push_queue.get() + await self.push_frame(frame, direction) + except asyncio.CancelledError: + break + + # + # Handle interruptions + # + + async def _handle_interruptions(self, frame: Frame): + if self.interruptions_allowed: + # Make sure we notify about interruptions quickly out-of-band + if isinstance(frame, UserStartedSpeakingFrame): + logger.debug("User started speaking") + # Cancel the task. This will stop pushing frames downstream. + self._push_frame_task.cancel() + await self._push_frame_task + # Push an out-of-band frame (i.e. not using the ordered push + # frame task) to stop everything, specially at the output + # transport. + await self.push_frame(StartInterruptionFrame()) + # Create a new queue and task. + self._create_push_task() + elif isinstance(frame, UserStoppedSpeakingFrame): + logger.debug("User stopped speaking") + await self.push_frame(StopInterruptionFrame()) + await self._internal_push_frame(frame) + + # + # Audio input + # + + async def _vad_analyze(self, audio_frames: bytes) -> VADState: + state = VADState.QUIET + vad_analyzer = self.vad_analyzer() + if vad_analyzer: + state = await self.get_event_loop().run_in_executor( + self._executor, vad_analyzer.analyze_audio, audio_frames) + return state + + async def _handle_vad(self, audio_frames: bytes, vad_state: VADState): + new_vad_state = await self._vad_analyze(audio_frames) + if new_vad_state != vad_state and new_vad_state != VADState.STARTING and new_vad_state != VADState.STOPPING: + frame = None + if new_vad_state == VADState.SPEAKING: + frame = UserStartedSpeakingFrame() + elif new_vad_state == VADState.QUIET: + frame = UserStoppedSpeakingFrame() + + if frame: + await self._handle_interruptions(frame) + + vad_state = new_vad_state + return vad_state + + async def _audio_task_handler(self): + vad_state: VADState = VADState.QUIET + while True: + try: + frame: AudioRawFrame = await self._audio_in_queue.get() + + audio_passthrough = True + + # Check VAD and push event if necessary. We just care about + # changes from QUIET to SPEAKING and vice versa. + if self._params.vad_enabled: + vad_state = await self._handle_vad(frame.audio, vad_state) + audio_passthrough = self._params.vad_audio_passthrough + + # Push audio downstream if passthrough. 
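+                # When VAD is enabled, the raw audio is only forwarded downstream if
+                # vad_audio_passthrough is set (e.g. for an STT service that still
+                # needs the audio).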
+ if audio_passthrough: + await self._internal_push_frame(frame) + except asyncio.CancelledError: + break + except Exception as e: + logger.exception(f"{self} error reading audio frames: {e}") diff --git a/pipecat/transports/base_output.py b/pipecat/transports/base_output.py new file mode 100644 index 0000000000000000000000000000000000000000..d25025dbe3633035f587b59081a00a501f116eb3 --- /dev/null +++ b/pipecat/transports/base_output.py @@ -0,0 +1,266 @@ +# +# Copyright (c) 2024, Daily + +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +import itertools + +from PIL import Image +from typing import List + +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.frames.frames import ( + AudioRawFrame, + CancelFrame, + MetricsFrame, + SpriteFrame, + StartFrame, + EndFrame, + Frame, + ImageRawFrame, + StartInterruptionFrame, + StopInterruptionFrame, + SystemFrame, + TransportMessageFrame) +from pipecat.transports.base_transport import TransportParams + +from loguru import logger + + +class BaseOutputTransport(FrameProcessor): + + def __init__(self, params: TransportParams, **kwargs): + super().__init__(**kwargs) + + self._params = params + + # These are the images that we should send to the camera at our desired + # framerate. + self._camera_images = None + + # We will write 20ms audio at a time. If we receive long audio frames we + # will chunk them. This will help with interruption handling. + audio_bytes_10ms = int(self._params.audio_out_sample_rate / 100) * \ + self._params.audio_out_channels * 2 + self._audio_chunk_size = audio_bytes_10ms * 2 + + self._stopped_event = asyncio.Event() + + # Create sink frame task. This is the task that will actually write + # audio or video frames. We write audio/video in a task so we can keep + # generating frames upstream while, for example, the audio is playing. + self._create_sink_task() + + # Create push frame task. This is the task that will push frames in + # order. We also guarantee that all frames are pushed in the same task. + self._create_push_task() + + async def start(self, frame: StartFrame): + # Create media threads queues. + if self._params.camera_out_enabled: + self._camera_out_queue = asyncio.Queue() + self._camera_out_task = self.get_event_loop().create_task(self._camera_out_task_handler()) + + async def stop(self): + # Wait on the threads to finish. + if self._params.camera_out_enabled: + self._camera_out_task.cancel() + await self._camera_out_task + + self._stopped_event.set() + + async def send_message(self, frame: TransportMessageFrame): + pass + + async def send_metrics(self, frame: MetricsFrame): + pass + + async def write_frame_to_camera(self, frame: ImageRawFrame): + pass + + async def write_raw_audio_frames(self, frames: bytes): + pass + + # + # Frame processor + # + + async def cleanup(self): + if self._sink_task: + self._sink_task.cancel() + await self._sink_task + + self._push_frame_task.cancel() + await self._push_frame_task + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + # + # Out-of-band frames like (CancelFrame or StartInterruptionFrame) are + # pushed immediately. Other frames require order so they are put in the + # sink queue. + # + if isinstance(frame, StartFrame): + await self.start(frame) + await self.push_frame(frame, direction) + # EndFrame is managed in the sink queue handler. 
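+        # An EndFrame is queued in the sink so that any audio or images still
+        # pending are written out before the transport stops.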
+ elif isinstance(frame, CancelFrame): + await self.stop() + await self.push_frame(frame, direction) + elif isinstance(frame, StartInterruptionFrame) or isinstance(frame, StopInterruptionFrame): + await self._handle_interruptions(frame) + await self.push_frame(frame, direction) + elif isinstance(frame, MetricsFrame): + await self.send_metrics(frame) + await self.push_frame(frame, direction) + elif isinstance(frame, SystemFrame): + await self.push_frame(frame, direction) + elif isinstance(frame, AudioRawFrame): + await self._handle_audio(frame) + else: + await self._sink_queue.put(frame) + + # If we are finishing, wait here until we have stopped, otherwise we might + # close things too early upstream. We need this event because we don't + # know when the internal threads will finish. + if isinstance(frame, CancelFrame) or isinstance(frame, EndFrame): + await self._stopped_event.wait() + + async def _handle_interruptions(self, frame: Frame): + if not self.interruptions_allowed: + return + + if isinstance(frame, StartInterruptionFrame): + # Stop sink task. + self._sink_task.cancel() + await self._sink_task + self._create_sink_task() + # Stop push task. + self._push_frame_task.cancel() + await self._push_frame_task + self._create_push_task() + + async def _handle_audio(self, frame: AudioRawFrame): + audio = frame.audio + for i in range(0, len(audio), self._audio_chunk_size): + chunk = AudioRawFrame(audio[i: i + self._audio_chunk_size], + sample_rate=frame.sample_rate, num_channels=frame.num_channels) + await self._sink_queue.put(chunk) + + def _create_sink_task(self): + loop = self.get_event_loop() + self._sink_queue = asyncio.Queue() + self._sink_task = loop.create_task(self._sink_task_handler()) + + async def _sink_task_handler(self): + # Audio accumlation buffer + buffer = bytearray() + while True: + try: + frame = await self._sink_queue.get() + if isinstance(frame, AudioRawFrame) and self._params.audio_out_enabled: + buffer.extend(frame.audio) + buffer = await self._maybe_send_audio(buffer) + elif isinstance(frame, ImageRawFrame) and self._params.camera_out_enabled: + await self._set_camera_image(frame) + elif isinstance(frame, SpriteFrame) and self._params.camera_out_enabled: + await self._set_camera_images(frame.images) + elif isinstance(frame, TransportMessageFrame): + await self.send_message(frame) + else: + await self._internal_push_frame(frame) + + if isinstance(frame, EndFrame): + await self.stop() + + self._sink_queue.task_done() + except asyncio.CancelledError: + break + except Exception as e: + logger.exception(f"{self} error processing sink queue: {e}") + + # + # Push frames task + # + + def _create_push_task(self): + loop = self.get_event_loop() + self._push_queue = asyncio.Queue() + self._push_frame_task = loop.create_task(self._push_frame_task_handler()) + + async def _internal_push_frame( + self, + frame: Frame | None, + direction: FrameDirection | None = FrameDirection.DOWNSTREAM): + await self._push_queue.put((frame, direction)) + + async def _push_frame_task_handler(self): + while True: + try: + (frame, direction) = await self._push_queue.get() + await self.push_frame(frame, direction) + except asyncio.CancelledError: + break + + # + # Camera out + # + + async def send_image(self, frame: ImageRawFrame | SpriteFrame): + await self.process_frame(frame, FrameDirection.DOWNSTREAM) + + async def _draw_image(self, frame: ImageRawFrame): + desired_size = (self._params.camera_out_width, self._params.camera_out_height) + + if frame.size != desired_size: + image = 
Image.frombytes(frame.format, frame.size, frame.image) + resized_image = image.resize(desired_size) + logger.warning( + f"{frame} does not have the expected size {desired_size}, resizing") + frame = ImageRawFrame(resized_image.tobytes(), resized_image.size, resized_image.format) + + await self.write_frame_to_camera(frame) + + async def _set_camera_image(self, image: ImageRawFrame): + if self._params.camera_out_is_live: + await self._camera_out_queue.put(image) + else: + self._camera_images = itertools.cycle([image]) + + async def _set_camera_images(self, images: List[ImageRawFrame]): + self._camera_images = itertools.cycle(images) + + async def _camera_out_task_handler(self): + while True: + try: + if self._params.camera_out_is_live: + image = await self._camera_out_queue.get() + await self._draw_image(image) + self._camera_out_queue.task_done() + elif self._camera_images: + image = next(self._camera_images) + await self._draw_image(image) + await asyncio.sleep(1.0 / self._params.camera_out_framerate) + else: + await asyncio.sleep(1.0 / self._params.camera_out_framerate) + except asyncio.CancelledError: + break + except Exception as e: + logger.exception(f"{self} error writing to camera: {e}") + + # + # Audio out + # + + async def send_audio(self, frame: AudioRawFrame): + await self.process_frame(frame, FrameDirection.DOWNSTREAM) + + async def _maybe_send_audio(self, buffer: bytearray) -> bytearray: + if len(buffer) >= self._audio_chunk_size: + await self.write_raw_audio_frames(bytes(buffer[:self._audio_chunk_size])) + buffer = buffer[self._audio_chunk_size:] + return buffer diff --git a/pipecat/transports/base_transport.py b/pipecat/transports/base_transport.py new file mode 100644 index 0000000000000000000000000000000000000000..fcc8a6068c2b706a8a2488915b63a7b3e1a275ae --- /dev/null +++ b/pipecat/transports/base_transport.py @@ -0,0 +1,85 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +import inspect + +from abc import ABC, abstractmethod + +from pydantic import ConfigDict +from pydantic.main import BaseModel + +from pipecat.processors.frame_processor import FrameProcessor +from pipecat.vad.vad_analyzer import VADAnalyzer + +from loguru import logger + + +class TransportParams(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + camera_out_enabled: bool = False + camera_out_is_live: bool = False + camera_out_width: int = 1024 + camera_out_height: int = 768 + camera_out_bitrate: int = 800000 + camera_out_framerate: int = 30 + camera_out_color_format: str = "RGB" + audio_out_enabled: bool = False + audio_out_sample_rate: int = 16000 + audio_out_channels: int = 1 + audio_in_enabled: bool = False + audio_in_sample_rate: int = 16000 + audio_in_channels: int = 1 + vad_enabled: bool = False + vad_audio_passthrough: bool = False + vad_analyzer: VADAnalyzer | None = None + + +class BaseTransport(ABC): + + def __init__(self, + input_name: str | None = None, + output_name: str | None = None, + loop: asyncio.AbstractEventLoop | None = None): + self._input_name = input_name + self._output_name = output_name + self._loop = loop or asyncio.get_running_loop() + self._event_handlers: dict = {} + + @abstractmethod + def input(self) -> FrameProcessor: + raise NotImplementedError + + @abstractmethod + def output(self) -> FrameProcessor: + raise NotImplementedError + + def event_handler(self, event_name: str): + def decorator(handler): + self._add_event_handler(event_name, handler) + return handler + return decorator + + def 
_register_event_handler(self, event_name: str): + if event_name in self._event_handlers: + raise Exception(f"Event handler {event_name} already registered") + self._event_handlers[event_name] = [] + + def _add_event_handler(self, event_name: str, handler): + if event_name not in self._event_handlers: + raise Exception(f"Event handler {event_name} not registered") + self._event_handlers[event_name].append(handler) + + async def _call_event_handler(self, event_name: str, *args, **kwargs): + try: + for handler in self._event_handlers[event_name]: + if inspect.iscoroutinefunction(handler): + await handler(self, *args, **kwargs) + else: + handler(self, *args, **kwargs) + except Exception as e: + logger.exception(f"Exception in event handler {event_name}: {e}") diff --git a/pipecat/transports/local/__init__.py b/pipecat/transports/local/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pipecat/transports/local/audio.py b/pipecat/transports/local/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..579ee36f383df3b35781e603aa4db4c6d05a7d4b --- /dev/null +++ b/pipecat/transports/local/audio.py @@ -0,0 +1,116 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio + +from concurrent.futures import ThreadPoolExecutor + +from pipecat.frames.frames import AudioRawFrame, StartFrame +from pipecat.processors.frame_processor import FrameProcessor +from pipecat.transports.base_input import BaseInputTransport +from pipecat.transports.base_output import BaseOutputTransport +from pipecat.transports.base_transport import BaseTransport, TransportParams + +from loguru import logger + +try: + import pyaudio +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error( + "In order to use local audio, you need to `pip install pipecat-ai[local]`. On MacOS, you also need to `brew install portaudio`.") + raise Exception(f"Missing module: {e}") + + +class LocalAudioInputTransport(BaseInputTransport): + + def __init__(self, py_audio: pyaudio.PyAudio, params: TransportParams): + super().__init__(params) + + sample_rate = self._params.audio_in_sample_rate + num_frames = int(sample_rate / 100) * 2 # 20ms of audio + + self._in_stream = py_audio.open( + format=py_audio.get_format_from_width(2), + channels=params.audio_in_channels, + rate=params.audio_in_sample_rate, + frames_per_buffer=num_frames, + stream_callback=self._audio_in_callback, + input=True) + + async def start(self, frame: StartFrame): + await super().start(frame) + self._in_stream.start_stream() + + async def cleanup(self): + await super().cleanup() + self._in_stream.stop_stream() + # This is not very pretty (taken from PyAudio docs). 
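+        # Poll until PortAudio reports the stream as inactive before closing it,
+        # as recommended by the PyAudio examples.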
+ while self._in_stream.is_active(): + await asyncio.sleep(0.1) + self._in_stream.close() + + def _audio_in_callback(self, in_data, frame_count, time_info, status): + frame = AudioRawFrame(audio=in_data, + sample_rate=self._params.audio_in_sample_rate, + num_channels=self._params.audio_in_channels) + + asyncio.run_coroutine_threadsafe(self.push_audio_frame(frame), self.get_event_loop()) + + return (None, pyaudio.paContinue) + + +class LocalAudioOutputTransport(BaseOutputTransport): + + def __init__(self, py_audio: pyaudio.PyAudio, params: TransportParams): + super().__init__(params) + + self._executor = ThreadPoolExecutor(max_workers=5) + + self._out_stream = py_audio.open( + format=py_audio.get_format_from_width(2), + channels=params.audio_out_channels, + rate=params.audio_out_sample_rate, + output=True) + + async def start(self, frame: StartFrame): + await super().start(frame) + self._out_stream.start_stream() + + async def cleanup(self): + await super().cleanup() + self._out_stream.stop_stream() + # This is not very pretty (taken from PyAudio docs). + while self._out_stream.is_active(): + await asyncio.sleep(0.1) + self._out_stream.close() + + async def write_raw_audio_frames(self, frames: bytes): + await self.get_event_loop().run_in_executor(self._executor, self._out_stream.write, frames) + + +class LocalAudioTransport(BaseTransport): + + def __init__(self, params: TransportParams): + self._params = params + self._pyaudio = pyaudio.PyAudio() + + self._input: LocalAudioInputTransport | None = None + self._output: LocalAudioOutputTransport | None = None + + # + # BaseTransport + # + + def input(self) -> FrameProcessor: + if not self._input: + self._input = LocalAudioInputTransport(self._pyaudio, self._params) + return self._input + + def output(self) -> FrameProcessor: + if not self._output: + self._output = LocalAudioOutputTransport(self._pyaudio, self._params) + return self._output diff --git a/pipecat/transports/local/tk.py b/pipecat/transports/local/tk.py new file mode 100644 index 0000000000000000000000000000000000000000..be58e77cc1e0b639afeee4de5881c69735ae1834 --- /dev/null +++ b/pipecat/transports/local/tk.py @@ -0,0 +1,152 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio + +from concurrent.futures import ThreadPoolExecutor + +import numpy as np +import tkinter as tk + +from pipecat.frames.frames import AudioRawFrame, ImageRawFrame, StartFrame +from pipecat.processors.frame_processor import FrameProcessor +from pipecat.transports.base_input import BaseInputTransport +from pipecat.transports.base_output import BaseOutputTransport +from pipecat.transports.base_transport import BaseTransport, TransportParams + +from loguru import logger + +try: + import pyaudio +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error( + "In order to use local audio, you need to `pip install pipecat-ai[audio]`. On MacOS, you also need to `brew install portaudio`.") + raise Exception(f"Missing module: {e}") + +try: + import tkinter as tk +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error("tkinter missing. 
Try `apt install python3-tk` or `brew install python-tk@3.10`.") + raise Exception(f"Missing module: {e}") + + +class TkInputTransport(BaseInputTransport): + + def __init__(self, py_audio: pyaudio.PyAudio, params: TransportParams): + super().__init__(params) + + sample_rate = self._params.audio_in_sample_rate + num_frames = int(sample_rate / 100) * 2 # 20ms of audio + + self._in_stream = py_audio.open( + format=py_audio.get_format_from_width(2), + channels=params.audio_in_channels, + rate=params.audio_in_sample_rate, + frames_per_buffer=num_frames, + stream_callback=self._audio_in_callback, + input=True) + + async def start(self, frame: StartFrame): + await super().start(frame) + self._in_stream.start_stream() + + async def cleanup(self): + await super().cleanup() + self._in_stream.stop_stream() + # This is not very pretty (taken from PyAudio docs). + while self._in_stream.is_active(): + await asyncio.sleep(0.1) + self._in_stream.close() + + def _audio_in_callback(self, in_data, frame_count, time_info, status): + frame = AudioRawFrame(audio=in_data, + sample_rate=self._params.audio_in_sample_rate, + num_channels=self._params.audio_in_channels) + + asyncio.run_coroutine_threadsafe(self.push_audio_frame(frame), self.get_event_loop()) + + return (None, pyaudio.paContinue) + + +class TkOutputTransport(BaseOutputTransport): + + def __init__(self, tk_root: tk.Tk, py_audio: pyaudio.PyAudio, params: TransportParams): + super().__init__(params) + + self._executor = ThreadPoolExecutor(max_workers=5) + + self._out_stream = py_audio.open( + format=py_audio.get_format_from_width(2), + channels=params.audio_out_channels, + rate=params.audio_out_sample_rate, + output=True) + + # Start with a neutral gray background. + array = np.ones((1024, 1024, 3)) * 128 + data = f"P5 {1024} {1024} 255 ".encode() + array.astype(np.uint8).tobytes() + photo = tk.PhotoImage(width=1024, height=1024, data=data, format="PPM") + self._image_label = tk.Label(tk_root, image=photo) + self._image_label.pack() + + async def start(self, frame: StartFrame): + await super().start(frame) + self._out_stream.start_stream() + + async def cleanup(self): + await super().cleanup() + self._out_stream.stop_stream() + # This is not very pretty (taken from PyAudio docs). + while self._out_stream.is_active(): + await asyncio.sleep(0.1) + self._out_stream.close() + + async def write_raw_audio_frames(self, frames: bytes): + await self.get_event_loop().run_in_executor(self._executor, self._out_stream.write, frames) + + async def write_frame_to_camera(self, frame: ImageRawFrame): + self.get_event_loop().call_soon(self._write_frame_to_tk, frame) + + def _write_frame_to_tk(self, frame: ImageRawFrame): + width = frame.size[0] + height = frame.size[1] + data = f"P6 {width} {height} 255 ".encode() + frame.image + photo = tk.PhotoImage( + width=width, + height=height, + data=data, + format="PPM") + self._image_label.config(image=photo) + + # This holds a reference to the photo, preventing it from being garbage + # collected. 
+ self._image_label.image = photo + + +class TkLocalTransport(BaseTransport): + + def __init__(self, tk_root: tk.Tk, params: TransportParams): + self._tk_root = tk_root + self._params = params + self._pyaudio = pyaudio.PyAudio() + + self._input: TkInputTransport | None = None + self._output: TkOutputTransport | None = None + + # + # BaseTransport + # + + def input(self) -> FrameProcessor: + if not self._input: + self._input = TkInputTransport(self._pyaudio, self._params) + return self._input + + def output(self) -> FrameProcessor: + if not self._output: + self._output = TkOutputTransport(self._tk_root, self._pyaudio, self._params) + return self._output diff --git a/pipecat/transports/network/__init__.py b/pipecat/transports/network/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pipecat/transports/network/fastapi_websocket.py b/pipecat/transports/network/fastapi_websocket.py new file mode 100644 index 0000000000000000000000000000000000000000..b244c97503c0f42649064ddb337b41ebfad06f65 --- /dev/null +++ b/pipecat/transports/network/fastapi_websocket.py @@ -0,0 +1,159 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + + +import asyncio +import io +import wave + +from typing import Awaitable, Callable +from pydantic.main import BaseModel + +from pipecat.frames.frames import AudioRawFrame, StartFrame +from pipecat.processors.frame_processor import FrameProcessor +from pipecat.serializers.base_serializer import FrameSerializer +from pipecat.transports.base_input import BaseInputTransport +from pipecat.transports.base_output import BaseOutputTransport +from pipecat.transports.base_transport import BaseTransport, TransportParams + +from loguru import logger + +try: + from fastapi import WebSocket + from starlette.websockets import WebSocketState +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error( + "In order to use FastAPI websockets, you need to `pip install pipecat-ai[websocket]`.") + raise Exception(f"Missing module: {e}") + + +class FastAPIWebsocketParams(TransportParams): + add_wav_header: bool = False + audio_frame_size: int = 6400 # 200ms + serializer: FrameSerializer + + +class FastAPIWebsocketCallbacks(BaseModel): + on_client_connected: Callable[[WebSocket], Awaitable[None]] + on_client_disconnected: Callable[[WebSocket], Awaitable[None]] + + +class FastAPIWebsocketInputTransport(BaseInputTransport): + + def __init__( + self, + websocket: WebSocket, + params: FastAPIWebsocketParams, + callbacks: FastAPIWebsocketCallbacks, + **kwargs): + super().__init__(params, **kwargs) + + self._websocket = websocket + self._params = params + self._callbacks = callbacks + + async def start(self, frame: StartFrame): + await self._callbacks.on_client_connected(self._websocket) + await super().start(frame) + self._receive_task = self.get_event_loop().create_task(self._receive_messages()) + + async def stop(self): + if self._websocket.client_state != WebSocketState.DISCONNECTED: + await self._websocket.close() + await super().stop() + + async def _receive_messages(self): + async for message in self._websocket.iter_text(): + frame = self._params.serializer.deserialize(message) + + if not frame: + continue + + if isinstance(frame, AudioRawFrame): + await self.push_audio_frame(frame) + + await self._callbacks.on_client_disconnected(self._websocket) + + +class FastAPIWebsocketOutputTransport(BaseOutputTransport): + + def __init__(self, websocket: 
WebSocket, params: FastAPIWebsocketParams, **kwargs): + super().__init__(params, **kwargs) + + self._websocket = websocket + self._params = params + self._audio_buffer = bytes() + + async def write_raw_audio_frames(self, frames: bytes): + self._audio_buffer += frames + while len(self._audio_buffer) >= self._params.audio_frame_size: + frame = AudioRawFrame( + audio=self._audio_buffer[:self._params.audio_frame_size], + sample_rate=self._params.audio_out_sample_rate, + num_channels=self._params.audio_out_channels + ) + + if self._params.add_wav_header: + content = io.BytesIO() + ww = wave.open(content, "wb") + ww.setsampwidth(2) + ww.setnchannels(frame.num_channels) + ww.setframerate(frame.sample_rate) + ww.writeframes(frame.audio) + ww.close() + content.seek(0) + wav_frame = AudioRawFrame( + content.read(), + sample_rate=frame.sample_rate, + num_channels=frame.num_channels) + frame = wav_frame + + payload = self._params.serializer.serialize(frame) + if payload and self._websocket.client_state == WebSocketState.CONNECTED: + await self._websocket.send_text(payload) + + self._audio_buffer = self._audio_buffer[self._params.audio_frame_size:] + + +class FastAPIWebsocketTransport(BaseTransport): + + def __init__( + self, + websocket: WebSocket, + params: FastAPIWebsocketParams, + input_name: str | None = None, + output_name: str | None = None, + loop: asyncio.AbstractEventLoop | None = None): + super().__init__(input_name=input_name, output_name=output_name, loop=loop) + self._params = params + + self._callbacks = FastAPIWebsocketCallbacks( + on_client_connected=self._on_client_connected, + on_client_disconnected=self._on_client_disconnected + ) + + self._input = FastAPIWebsocketInputTransport( + websocket, self._params, self._callbacks, name=self._input_name) + self._output = FastAPIWebsocketOutputTransport( + websocket, self._params, name=self._output_name) + + # Register supported handlers. The user will only be able to register + # these handlers. 
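+        # For example (illustrative), after constructing the transport a user can do:
+        #
+        #   @transport.event_handler("on_client_connected")
+        #   async def on_client_connected(transport, websocket):
+        #       ...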
+ self._register_event_handler("on_client_connected") + self._register_event_handler("on_client_disconnected") + + def input(self) -> FrameProcessor: + return self._input + + def output(self) -> FrameProcessor: + return self._output + + async def _on_client_connected(self, websocket): + await self._call_event_handler("on_client_connected", websocket) + + async def _on_client_disconnected(self, websocket): + await self._call_event_handler("on_client_disconnected", websocket) diff --git a/pipecat/transports/network/websocket_server.py b/pipecat/transports/network/websocket_server.py new file mode 100644 index 0000000000000000000000000000000000000000..d713df87ccd2c6f28fc2800989c19f4ab2fd08a1 --- /dev/null +++ b/pipecat/transports/network/websocket_server.py @@ -0,0 +1,211 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +import io +import wave + +from typing import Awaitable, Callable +from pydantic.main import BaseModel + +from pipecat.frames.frames import AudioRawFrame, StartFrame +from pipecat.processors.frame_processor import FrameProcessor +from pipecat.serializers.base_serializer import FrameSerializer +from pipecat.serializers.protobuf import ProtobufFrameSerializer +from pipecat.transports.base_input import BaseInputTransport +from pipecat.transports.base_output import BaseOutputTransport +from pipecat.transports.base_transport import BaseTransport, TransportParams + +from loguru import logger + +try: + import websockets +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error("In order to use websockets, you need to `pip install pipecat-ai[websocket]`.") + raise Exception(f"Missing module: {e}") + + +class WebsocketServerParams(TransportParams): + add_wav_header: bool = False + audio_frame_size: int = 6400 # 200ms + serializer: FrameSerializer = ProtobufFrameSerializer() + + +class WebsocketServerCallbacks(BaseModel): + on_client_connected: Callable[[websockets.WebSocketServerProtocol], Awaitable[None]] + on_client_disconnected: Callable[[websockets.WebSocketServerProtocol], Awaitable[None]] + + +class WebsocketServerInputTransport(BaseInputTransport): + + def __init__( + self, + host: str, + port: int, + params: WebsocketServerParams, + callbacks: WebsocketServerCallbacks, + **kwargs): + super().__init__(params, **kwargs) + + self._host = host + self._port = port + self._params = params + self._callbacks = callbacks + + self._websocket: websockets.WebSocketServerProtocol | None = None + + self._stop_server_event = asyncio.Event() + + async def start(self, frame: StartFrame): + self._server_task = self.get_event_loop().create_task(self._server_task_handler()) + await super().start(frame) + + async def stop(self): + self._stop_server_event.set() + await self._server_task + await super().stop() + + async def _server_task_handler(self): + logger.info(f"Starting websocket server on {self._host}:{self._port}") + async with websockets.serve(self._client_handler, self._host, self._port) as server: + await self._stop_server_event.wait() + + async def _client_handler(self, websocket: websockets.WebSocketServerProtocol, path): + logger.info(f"New client connection from {websocket.remote_address}") + if self._websocket: + await self._websocket.close() + logger.warning("Only one client connected, using new connection") + + self._websocket = websocket + + # Notify + await self._callbacks.on_client_connected(websocket) + + # Handle incoming messages + async for message in websocket: + frame = 
self._params.serializer.deserialize(message) + + if not frame: + continue + + if isinstance(frame, AudioRawFrame): + await self.push_audio_frame(frame) + else: + await self._internal_push_frame(frame) + + # Notify disconnection + await self._callbacks.on_client_disconnected(websocket) + + await self._websocket.close() + self._websocket = None + + logger.info(f"Client {websocket.remote_address} disconnected") + + +class WebsocketServerOutputTransport(BaseOutputTransport): + + def __init__(self, params: WebsocketServerParams, **kwargs): + super().__init__(params, **kwargs) + + self._params = params + + self._websocket: websockets.WebSocketServerProtocol | None = None + + self._audio_buffer = bytes() + + async def set_client_connection(self, websocket: websockets.WebSocketServerProtocol | None): + if self._websocket: + await self._websocket.close() + logger.warning("Only one client allowed, using new connection") + self._websocket = websocket + + async def write_raw_audio_frames(self, frames: bytes): + if not self._websocket: + return + + self._audio_buffer += frames + while len(self._audio_buffer) >= self._params.audio_frame_size: + frame = AudioRawFrame( + audio=self._audio_buffer[:self._params.audio_frame_size], + sample_rate=self._params.audio_out_sample_rate, + num_channels=self._params.audio_out_channels + ) + + if self._params.add_wav_header: + content = io.BytesIO() + ww = wave.open(content, "wb") + ww.setsampwidth(2) + ww.setnchannels(frame.num_channels) + ww.setframerate(frame.sample_rate) + ww.writeframes(frame.audio) + ww.close() + content.seek(0) + wav_frame = AudioRawFrame( + content.read(), + sample_rate=frame.sample_rate, + num_channels=frame.num_channels) + frame = wav_frame + + proto = self._params.serializer.serialize(frame) + if proto: + await self._websocket.send(proto) + + self._audio_buffer = self._audio_buffer[self._params.audio_frame_size:] + + +class WebsocketServerTransport(BaseTransport): + + def __init__( + self, + host: str = "localhost", + port: int = 8765, + params: WebsocketServerParams = WebsocketServerParams(), + input_name: str | None = None, + output_name: str | None = None, + loop: asyncio.AbstractEventLoop | None = None): + super().__init__(input_name=input_name, output_name=output_name, loop=loop) + self._host = host + self._port = port + self._params = params + + self._callbacks = WebsocketServerCallbacks( + on_client_connected=self._on_client_connected, + on_client_disconnected=self._on_client_disconnected + ) + self._input: WebsocketServerInputTransport | None = None + self._output: WebsocketServerOutputTransport | None = None + self._websocket: websockets.WebSocketServerProtocol | None = None + + # Register supported handlers. The user will only be able to register + # these handlers. 
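Compared to the FastAPI variant, this transport has workable defaults: `ProtobufFrameSerializer` and a 6400-byte `audio_frame_size`, which is 200 ms of 16 kHz, mono, 16-bit PCM (16000 samples/s × 2 bytes × 0.2 s = 6400). A hedged construction sketch; `audio_out_enabled` is a `TransportParams` flag referenced elsewhere in this diff, and the pipeline that would consume `input()`/`output()` lives elsewhere in pipecat and is omitted here:

```python
from pipecat.transports.network.websocket_server import (
    WebsocketServerParams,
    WebsocketServerTransport,
)

# 200 ms of 16 kHz mono 16-bit PCM is 6400 bytes, matching the default above.
assert 16000 * 2 * 200 // 1000 == 6400

transport = WebsocketServerTransport(
    host="0.0.0.0",
    port=8765,
    params=WebsocketServerParams(
        audio_out_enabled=True,  # TransportParams flag (used by the Daily transport below)
        add_wav_header=True,     # wrap each serialized chunk in a WAV header
    ),
)

# A pipeline (defined elsewhere in pipecat) would consume transport.input()
# and transport.output() as its source and sink FrameProcessors.
```

With `add_wav_header=True`, each 200 ms chunk is wrapped in its own WAV header before being handed to the serializer, which is what the `wave` branch of `write_raw_audio_frames()` below implements.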
+ self._register_event_handler("on_client_connected") + self._register_event_handler("on_client_disconnected") + + def input(self) -> FrameProcessor: + if not self._input: + self._input = WebsocketServerInputTransport( + self._host, self._port, self._params, self._callbacks, name=self._input_name) + return self._input + + def output(self) -> FrameProcessor: + if not self._output: + self._output = WebsocketServerOutputTransport(self._params, name=self._output_name) + return self._output + + async def _on_client_connected(self, websocket): + if self._output: + await self._output.set_client_connection(websocket) + await self._call_event_handler("on_client_connected", websocket) + else: + logger.error("A WebsocketServerTransport output is missing in the pipeline") + + async def _on_client_disconnected(self, websocket): + if self._output: + await self._output.set_client_connection(None) + await self._call_event_handler("on_client_disconnected", websocket) + else: + logger.error("A WebsocketServerTransport output is missing in the pipeline") diff --git a/pipecat/transports/services/__init__.py b/pipecat/transports/services/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pipecat/transports/services/daily.py b/pipecat/transports/services/daily.py new file mode 100644 index 0000000000000000000000000000000000000000..6b173765fd0624178a86228326f0b3d1c0c9d3a8 --- /dev/null +++ b/pipecat/transports/services/daily.py @@ -0,0 +1,881 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import aiohttp +import asyncio +import time + +from dataclasses import dataclass +from typing import Any, Awaitable, Callable, Mapping +from concurrent.futures import ThreadPoolExecutor + +from daily import ( + CallClient, + Daily, + EventHandler, + VirtualCameraDevice, + VirtualMicrophoneDevice, + VirtualSpeakerDevice) +from pydantic.main import BaseModel + +from pipecat.frames.frames import ( + AudioRawFrame, + Frame, + ImageRawFrame, + InterimTranscriptionFrame, + MetricsFrame, + SpriteFrame, + StartFrame, + TranscriptionFrame, + TransportMessageFrame, + UserImageRawFrame, + UserImageRequestFrame) +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.transports.base_input import BaseInputTransport +from pipecat.transports.base_output import BaseOutputTransport +from pipecat.transports.base_transport import BaseTransport, TransportParams +from pipecat.vad.vad_analyzer import VADAnalyzer, VADParams + +from loguru import logger + +try: + from daily import (EventHandler, CallClient, Daily) +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error( + "In order to use the Daily transport, you need to `pip install pipecat-ai[daily]`.") + raise Exception(f"Missing module: {e}") + +VAD_RESET_PERIOD_MS = 2000 + + +@dataclass +class DailyTransportMessageFrame(TransportMessageFrame): + participant_id: str | None = None + + +class WebRTCVADAnalyzer(VADAnalyzer): + + def __init__(self, *, sample_rate=16000, num_channels=1, params: VADParams = VADParams()): + super().__init__(sample_rate=sample_rate, num_channels=num_channels, params=params) + + self._webrtc_vad = Daily.create_native_vad( + reset_period_ms=VAD_RESET_PERIOD_MS, + sample_rate=sample_rate, + channels=num_channels + ) + logger.debug("Loaded native WebRTC VAD") + + def num_frames_required(self) -> int: + return int(self.sample_rate / 100.0) + + def voice_confidence(self, buffer) -> 
float: + confidence = 0 + if len(buffer) > 0: + confidence = self._webrtc_vad.analyze_frames(buffer) + return confidence + + +class DailyDialinSettings(BaseModel): + call_id: str = "" + call_domain: str = "" + + +class DailyTranscriptionSettings(BaseModel): + language: str = "en" + tier: str = "nova" + model: str = "2-conversationalai" + profanity_filter: bool = True + redact: bool = False + endpointing: bool = True + punctuate: bool = True + includeRawResponse: bool = True + extra: Mapping[str, Any] = { + "interim_results": True + } + + +class DailyParams(TransportParams): + api_url: str = "https://api.daily.co/v1" + api_key: str = "" + dialin_settings: DailyDialinSettings | None = None + transcription_enabled: bool = False + transcription_settings: DailyTranscriptionSettings = DailyTranscriptionSettings() + + +class DailyCallbacks(BaseModel): + on_joined: Callable[[Mapping[str, Any]], Awaitable[None]] + on_left: Callable[[], Awaitable[None]] + on_error: Callable[[str], Awaitable[None]] + on_app_message: Callable[[Any, str], Awaitable[None]] + on_call_state_updated: Callable[[str], Awaitable[None]] + on_dialin_ready: Callable[[str], Awaitable[None]] + on_dialout_answered: Callable[[Any], Awaitable[None]] + on_dialout_connected: Callable[[Any], Awaitable[None]] + on_dialout_stopped: Callable[[Any], Awaitable[None]] + on_dialout_error: Callable[[Any], Awaitable[None]] + on_dialout_warning: Callable[[Any], Awaitable[None]] + on_first_participant_joined: Callable[[Mapping[str, Any]], Awaitable[None]] + on_participant_joined: Callable[[Mapping[str, Any]], Awaitable[None]] + on_participant_left: Callable[[Mapping[str, Any], str], Awaitable[None]] + + +def completion_callback(future): + def _callback(*args): + if not future.cancelled(): + future.get_loop().call_soon_threadsafe(future.set_result, *args) + return _callback + + +class DailyTransportClient(EventHandler): + + _daily_initialized: bool = False + + # This is necessary to override EventHandler's __new__ method. 
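One pattern worth calling out before the client internals: `completion_callback()` above bridges daily-python's worker-thread completions back onto the asyncio loop via `call_soon_threadsafe()`, so callers can simply `await` a future. A self-contained illustration of the same pattern, with a hypothetical threaded API standing in for daily-python (none of this is part of the diff):

```python
import asyncio
import threading


def completion_callback(future):
    # Same shape as the helper above: marshal the result back to the loop thread.
    def _callback(*args):
        if not future.cancelled():
            future.get_loop().call_soon_threadsafe(future.set_result, *args)
    return _callback


async def main():
    future = asyncio.get_running_loop().create_future()

    # Hypothetical stand-in for a daily-python call that takes `completion=`
    # and invokes it from a worker thread.
    threading.Thread(target=completion_callback(future), args=("done",)).start()

    print(await future)  # -> done


asyncio.run(main())
```

The client methods below (`send_message()`, `read_next_audio_frame()`, `write_raw_audio_frames()`) all follow this same future-plus-completion shape.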
+ def __new__(cls, *args, **kwargs): + return super().__new__(cls) + + def __init__( + self, + room_url: str, + token: str | None, + bot_name: str, + params: DailyParams, + callbacks: DailyCallbacks, + loop: asyncio.AbstractEventLoop): + super().__init__() + + if not self._daily_initialized: + self._daily_initialized = True + Daily.init() + + self._room_url: str = room_url + self._token: str | None = token + self._bot_name: str = bot_name + self._params: DailyParams = params + self._callbacks = callbacks + self._loop = loop + + self._participant_id: str = "" + self._video_renderers = {} + self._transcription_renderers = {} + self._other_participant_has_joined = False + + self._joined = False + self._joining = False + self._leaving = False + + self._executor = ThreadPoolExecutor(max_workers=5) + + self._client: CallClient = CallClient(event_handler=self) + + self._camera: VirtualCameraDevice = Daily.create_camera_device( + "camera", + width=self._params.camera_out_width, + height=self._params.camera_out_height, + color_format=self._params.camera_out_color_format) + + self._mic: VirtualMicrophoneDevice = Daily.create_microphone_device( + "mic", + sample_rate=self._params.audio_out_sample_rate, + channels=self._params.audio_out_channels, + non_blocking=True) + + self._speaker: VirtualSpeakerDevice = Daily.create_speaker_device( + "speaker", + sample_rate=self._params.audio_in_sample_rate, + channels=self._params.audio_in_channels, + non_blocking=True) + Daily.select_speaker_device("speaker") + + @property + def participant_id(self) -> str: + return self._participant_id + + def set_callbacks(self, callbacks: DailyCallbacks): + self._callbacks = callbacks + + async def send_message(self, frame: DailyTransportMessageFrame): + future = self._loop.create_future() + self._client.send_app_message( + frame.message, + frame.participant_id, + completion=completion_callback(future)) + await future + + async def read_next_audio_frame(self) -> AudioRawFrame | None: + sample_rate = self._params.audio_in_sample_rate + num_channels = self._params.audio_in_channels + num_frames = int(sample_rate / 100) * 2 # 20ms of audio + + future = self._loop.create_future() + self._speaker.read_frames(num_frames, completion=completion_callback(future)) + audio = await future + + if len(audio) > 0: + return AudioRawFrame(audio=audio, sample_rate=sample_rate, num_channels=num_channels) + else: + # If we don't read any audio it could be there's no participant + # connected. daily-python will return immediately if that's the + # case, so let's sleep for a little bit (i.e. busy wait). + await asyncio.sleep(0.01) + return None + + async def write_raw_audio_frames(self, frames: bytes): + future = self._loop.create_future() + self._mic.write_frames(frames, completion=completion_callback(future)) + await future + + async def write_frame_to_camera(self, frame: ImageRawFrame): + self._camera.write_frame(frame.image) + + async def join(self): + # Transport already joined, ignore. + if self._joined or self._joining: + return + + logger.info(f"Joining {self._room_url}") + + self._joining = True + + # For performance reasons, never subscribe to video streams (unless a + # video renderer is registered). 
+ self._client.update_subscription_profiles({ + "base": { + "camera": "unsubscribed", + "screenVideo": "unsubscribed" + } + }) + + self._client.set_user_name(self._bot_name) + + try: + (data, error) = await self._join() + + if not error: + self._joined = True + self._joining = False + + logger.info(f"Joined {self._room_url}") + + if self._token and self._params.transcription_enabled: + logger.info( + f"Enabling transcription with settings {self._params.transcription_settings}") + self._client.start_transcription( + self._params.transcription_settings.model_dump()) + + await self._callbacks.on_joined(data["participants"]["local"]) + else: + error_msg = f"Error joining {self._room_url}: {error}" + logger.error(error_msg) + await self._callbacks.on_error(error_msg) + except asyncio.TimeoutError: + error_msg = f"Time out joining {self._room_url}" + logger.error(error_msg) + await self._callbacks.on_error(error_msg) + + async def _join(self): + future = self._loop.create_future() + + def handle_join_response(data, error): + if not future.cancelled(): + future.get_loop().call_soon_threadsafe(future.set_result, (data, error)) + + self._client.join( + self._room_url, + self._token, + completion=handle_join_response, + client_settings={ + "inputs": { + "camera": { + "isEnabled": self._params.camera_out_enabled, + "settings": { + "deviceId": "camera", + }, + }, + "microphone": { + "isEnabled": self._params.audio_out_enabled, + "settings": { + "deviceId": "mic", + "customConstraints": { + "autoGainControl": {"exact": False}, + "echoCancellation": {"exact": False}, + "noiseSuppression": {"exact": False}, + }, + }, + }, + }, + "publishing": { + "camera": { + "sendSettings": { + "maxQuality": "low", + "encodings": { + "low": { + "maxBitrate": self._params.camera_out_bitrate, + "maxFramerate": self._params.camera_out_framerate, + } + }, + } + } + }, + }) + + return await asyncio.wait_for(future, timeout=10) + + async def leave(self): + # Transport not joined, ignore. 
+ if not self._joined or self._leaving: + return + + self._joined = False + self._leaving = True + + logger.info(f"Leaving {self._room_url}") + + if self._params.transcription_enabled: + self._client.stop_transcription() + + try: + error = await self._leave() + if not error: + self._leaving = False + logger.info(f"Left {self._room_url}") + await self._callbacks.on_left() + else: + error_msg = f"Error leaving {self._room_url}: {error}" + logger.error(error_msg) + await self._callbacks.on_error(error_msg) + except asyncio.TimeoutError: + error_msg = f"Time out leaving {self._room_url}" + logger.error(error_msg) + await self._callbacks.on_error(error_msg) + + async def _leave(self): + future = self._loop.create_future() + + def handle_leave_response(error): + if not future.cancelled(): + future.get_loop().call_soon_threadsafe(future.set_result, error) + + self._client.leave(completion=handle_leave_response) + + return await asyncio.wait_for(future, timeout=10) + + async def cleanup(self): + await self._loop.run_in_executor(self._executor, self._cleanup) + + def _cleanup(self): + if self._client: + self._client.release() + self._client = None + + def participants(self): + return self._client.participants() + + def participant_counts(self): + return self._client.participant_counts() + + def start_dialout(self, settings): + self._client.start_dialout(settings) + + def stop_dialout(self, participant_id): + self._client.stop_dialout(participant_id) + + def start_recording(self, streaming_settings, stream_id, force_new): + self._client.start_recording(streaming_settings, stream_id, force_new) + + def stop_recording(self, stream_id): + self._client.stop_recording(stream_id) + + def capture_participant_transcription(self, participant_id: str, callback: Callable): + if not self._params.transcription_enabled: + return + + self._transcription_renderers[participant_id] = callback + + def capture_participant_video( + self, + participant_id: str, + callback: Callable, + framerate: int = 30, + video_source: str = "camera", + color_format: str = "RGB"): + # Only enable camera subscription on this participant + self._client.update_subscriptions(participant_settings={ + participant_id: { + "media": "subscribed" + } + }) + + self._video_renderers[participant_id] = callback + + self._client.set_video_renderer( + participant_id, + self._video_frame_received, + video_source=video_source, + color_format=color_format) + + # + # + # Daily (EventHandler) + # + + def on_app_message(self, message: Any, sender: str): + self._call_async_callback(self._callbacks.on_app_message, message, sender) + + def on_call_state_updated(self, state: str): + self._call_async_callback(self._callbacks.on_call_state_updated, state) + + def on_dialin_ready(self, sip_endpoint: str): + self._call_async_callback(self._callbacks.on_dialin_ready, sip_endpoint) + + def on_dialout_answered(self, data: Any): + self._call_async_callback(self._callbacks.on_dialout_answered, data) + + def on_dialout_connected(self, data: Any): + self._call_async_callback(self._callbacks.on_dialout_connected, data) + + def on_dialout_stopped(self, data: Any): + self._call_async_callback(self._callbacks.on_dialout_stopped, data) + + def on_dialout_error(self, data: Any): + self._call_async_callback(self._callbacks.on_dialout_error, data) + + def on_dialout_warning(self, data: Any): + self._call_async_callback(self._callbacks.on_dialout_warning, data) + + def on_participant_joined(self, participant): + id = participant["id"] + logger.info(f"Participant joined {id}") + + 
if not self._other_participant_has_joined: + self._other_participant_has_joined = True + self._call_async_callback(self._callbacks.on_first_participant_joined, participant) + + self._call_async_callback(self._callbacks.on_participant_joined, participant) + + def on_participant_left(self, participant, reason): + id = participant["id"] + logger.info(f"Participant left {id}") + + self._call_async_callback(self._callbacks.on_participant_left, participant, reason) + + def on_transcription_message(self, message: Mapping[str, Any]): + participant_id = "" + if "participantId" in message: + participant_id = message["participantId"] + + if participant_id in self._transcription_renderers: + callback = self._transcription_renderers[participant_id] + self._call_async_callback(callback, participant_id, message) + + def on_transcription_error(self, message): + logger.error(f"Transcription error: {message}") + + def on_transcription_started(self, status): + logger.debug(f"Transcription started: {status}") + + def on_transcription_stopped(self, stopped_by, stopped_by_error): + logger.debug("Transcription stopped") + + # + # Daily (CallClient callbacks) + # + + def _video_frame_received(self, participant_id, video_frame): + callback = self._video_renderers[participant_id] + self._call_async_callback( + callback, + participant_id, + video_frame.buffer, + (video_frame.width, + video_frame.height), + video_frame.color_format) + + def _call_async_callback(self, callback, *args): + future = asyncio.run_coroutine_threadsafe(callback(*args), self._loop) + future.result() + + +class DailyInputTransport(BaseInputTransport): + + def __init__(self, client: DailyTransportClient, params: DailyParams, **kwargs): + super().__init__(params, **kwargs) + + self._client = client + + self._video_renderers = {} + + self._vad_analyzer: VADAnalyzer | None = params.vad_analyzer + if params.vad_enabled and not params.vad_analyzer: + self._vad_analyzer = WebRTCVADAnalyzer( + sample_rate=self._params.audio_in_sample_rate, + num_channels=self._params.audio_in_channels) + + async def start(self, frame: StartFrame): + # Parent start. + await super().start(frame) + # Join the room. + await self._client.join() + # Create audio task. It reads audio frames from Daily and push them + # internally for VAD processing. + if self._params.audio_in_enabled or self._params.vad_enabled: + self._audio_in_task = self.get_event_loop().create_task(self._audio_in_task_handler()) + + async def stop(self): + # Parent stop. + await super().stop() + # Leave the room. + await self._client.leave() + # Stop audio thread. 
+ if self._params.audio_in_enabled or self._params.vad_enabled: + self._audio_in_task.cancel() + await self._audio_in_task + + async def cleanup(self): + await super().cleanup() + await self._client.cleanup() + + def vad_analyzer(self) -> VADAnalyzer | None: + return self._vad_analyzer + + # + # FrameProcessor + # + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if isinstance(frame, UserImageRequestFrame): + self.request_participant_image(frame.user_id) + + # + # Frames + # + + async def push_transcription_frame(self, frame: TranscriptionFrame | InterimTranscriptionFrame): + await self._internal_push_frame(frame) + + async def push_app_message(self, message: Any, sender: str): + frame = DailyTransportMessageFrame(message=message, participant_id=sender) + await self._internal_push_frame(frame) + + # + # Audio in + # + + async def _audio_in_task_handler(self): + while True: + try: + frame = await self._client.read_next_audio_frame() + if frame: + await self.push_audio_frame(frame) + except asyncio.CancelledError: + break + + # + # Camera in + # + + def capture_participant_video( + self, + participant_id: str, + framerate: int = 30, + video_source: str = "camera", + color_format: str = "RGB"): + self._video_renderers[participant_id] = { + "framerate": framerate, + "timestamp": 0, + "render_next_frame": False, + } + + self._client.capture_participant_video( + participant_id, + self._on_participant_video_frame, + framerate, + video_source, + color_format + ) + + def request_participant_image(self, participant_id: str): + if participant_id in self._video_renderers: + self._video_renderers[participant_id]["render_next_frame"] = True + + async def _on_participant_video_frame(self, participant_id: str, buffer, size, format): + render_frame = False + + curr_time = time.time() + prev_time = self._video_renderers[participant_id]["timestamp"] or curr_time + framerate = self._video_renderers[participant_id]["framerate"] + + if framerate > 0: + next_time = prev_time + 1 / framerate + render_frame = (curr_time - next_time) < 0.1 + elif self._video_renderers[participant_id]["render_next_frame"]: + self._video_renderers[participant_id]["render_next_frame"] = False + render_frame = True + + if render_frame: + frame = UserImageRawFrame( + user_id=participant_id, + image=buffer, + size=size, + format=format) + await self._internal_push_frame(frame) + + self._video_renderers[participant_id]["timestamp"] = curr_time + + +class DailyOutputTransport(BaseOutputTransport): + + def __init__(self, client: DailyTransportClient, params: DailyParams, **kwargs): + super().__init__(params, **kwargs) + + self._client = client + + async def start(self, frame: StartFrame): + # Parent start. + await super().start(frame) + # Join the room. + await self._client.join() + + async def stop(self): + # Parent stop. + await super().stop() + # Leave the room. 
+ await self._client.leave() + + async def cleanup(self): + await super().cleanup() + await self._client.cleanup() + + async def send_message(self, frame: DailyTransportMessageFrame): + await self._client.send_message(frame) + + async def send_metrics(self, frame: MetricsFrame): + message = DailyTransportMessageFrame(message={ + "type": "pipecat-metrics", + "metrics": { + "ttfb": frame.ttfb or [], + "processing": frame.processing or [], + }, + }) + await self._client.send_message(message) + + async def write_raw_audio_frames(self, frames: bytes): + await self._client.write_raw_audio_frames(frames) + + async def write_frame_to_camera(self, frame: ImageRawFrame): + await self._client.write_frame_to_camera(frame) + + +class DailyTransport(BaseTransport): + + def __init__( + self, + room_url: str, + token: str | None, + bot_name: str, + params: DailyParams, + input_name: str | None = None, + output_name: str | None = None, + loop: asyncio.AbstractEventLoop | None = None): + super().__init__(input_name=input_name, output_name=output_name, loop=loop) + + callbacks = DailyCallbacks( + on_joined=self._on_joined, + on_left=self._on_left, + on_error=self._on_error, + on_app_message=self._on_app_message, + on_call_state_updated=self._on_call_state_updated, + on_dialin_ready=self._on_dialin_ready, + on_dialout_answered=self._on_dialout_answered, + on_dialout_connected=self._on_dialout_connected, + on_dialout_stopped=self._on_dialout_stopped, + on_dialout_error=self._on_dialout_error, + on_dialout_warning=self._on_dialout_warning, + on_first_participant_joined=self._on_first_participant_joined, + on_participant_joined=self._on_participant_joined, + on_participant_left=self._on_participant_left, + ) + self._params = params + + self._client = DailyTransportClient( + room_url, token, bot_name, params, callbacks, self._loop) + self._input: DailyInputTransport | None = None + self._output: DailyOutputTransport | None = None + + # Register supported handlers. The user will only be able to register + # these handlers. 
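Pulling the pieces together, a typical bot wires this transport into a pipeline and reacts to the events registered just below. A hedged sketch: the room URL and token are placeholders, and `Pipeline`, `PipelineRunner`, `PipelineTask` and the `event_handler()` decorator are assumed to come from elsewhere in pipecat rather than from this file:

```python
import asyncio

from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineTask
from pipecat.transports.services.daily import DailyParams, DailyTransport


async def main():
    transport = DailyTransport(
        room_url="https://example.daily.co/room",  # placeholder
        token="<daily-meeting-token>",             # placeholder
        bot_name="pipecat-bot",
        params=DailyParams(
            audio_out_enabled=True,
            transcription_enabled=True,
            vad_enabled=True))

    @transport.event_handler("on_first_participant_joined")
    async def on_first_participant_joined(transport, participant):
        # Start routing this participant's transcription frames into the pipeline.
        transport.capture_participant_transcription(participant["id"])

    pipeline = Pipeline([transport.input(), transport.output()])
    await PipelineRunner().run(PipelineTask(pipeline))


if __name__ == "__main__":
    asyncio.run(main())
```

Transcription frames only flow for participants passed to `capture_participant_transcription()`, which is why the handler above requests it explicitly.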
+ self._register_event_handler("on_joined") + self._register_event_handler("on_left") + self._register_event_handler("on_app_message") + self._register_event_handler("on_call_state_updated") + self._register_event_handler("on_dialin_ready") + self._register_event_handler("on_dialout_answered") + self._register_event_handler("on_dialout_connected") + self._register_event_handler("on_dialout_stopped") + self._register_event_handler("on_dialout_error") + self._register_event_handler("on_dialout_warning") + self._register_event_handler("on_first_participant_joined") + self._register_event_handler("on_participant_joined") + self._register_event_handler("on_participant_left") + + # + # BaseTransport + # + + def input(self) -> FrameProcessor: + if not self._input: + self._input = DailyInputTransport(self._client, self._params, name=self._input_name) + return self._input + + def output(self) -> FrameProcessor: + if not self._output: + self._output = DailyOutputTransport(self._client, self._params, name=self._output_name) + return self._output + + # + # DailyTransport + # + + @ property + def participant_id(self) -> str: + return self._client.participant_id + + async def send_image(self, frame: ImageRawFrame | SpriteFrame): + if self._output: + await self._output.process_frame(frame, FrameDirection.DOWNSTREAM) + + async def send_audio(self, frame: AudioRawFrame): + if self._output: + await self._output.process_frame(frame, FrameDirection.DOWNSTREAM) + + def participants(self): + return self._client.participants() + + def participant_counts(self): + return self._client.participant_counts() + + def start_dialout(self, settings=None): + self._client.start_dialout(settings) + + def stop_dialout(self, participant_id): + self._client.stop_dialout(participant_id) + + def start_recording(self, streaming_settings=None, stream_id=None, force_new=None): + self._client.start_recording(streaming_settings, stream_id, force_new) + + def stop_recording(self, stream_id=None): + self._client.stop_recording(stream_id) + + def capture_participant_transcription(self, participant_id: str): + self._client.capture_participant_transcription( + participant_id, + self._on_transcription_message + ) + + def capture_participant_video( + self, + participant_id: str, + framerate: int = 30, + video_source: str = "camera", + color_format: str = "RGB"): + if self._input: + self._input.capture_participant_video( + participant_id, framerate, video_source, color_format) + + async def _on_joined(self, participant): + await self._call_event_handler("on_joined", participant) + + async def _on_left(self): + await self._call_event_handler("on_left") + + async def _on_error(self, error): + # TODO(aleix): Report error to input/output transports. The one managing + # the client should report the error. 
+ pass + + async def _on_app_message(self, message: Any, sender: str): + if self._input: + await self._input.push_app_message(message, sender) + await self._call_event_handler("on_app_message", message, sender) + + async def _on_call_state_updated(self, state: str): + await self._call_event_handler("on_call_state_updated", state) + + async def _handle_dialin_ready(self, sip_endpoint: str): + if not self._params.dialin_settings: + return + + async with aiohttp.ClientSession() as session: + headers = { + "Authorization": f"Bearer {self._params.api_key}", + "Content-Type": "application/json" + } + data = { + "callId": self._params.dialin_settings.call_id, + "callDomain": self._params.dialin_settings.call_domain, + "sipUri": sip_endpoint + } + + url = f"{self._params.api_url}/dialin/pinlessCallUpdate" + + try: + async with session.post(url, headers=headers, json=data, timeout=10) as r: + if r.status != 200: + text = await r.text() + logger.error( + f"Unable to handle dialin-ready event (status: {r.status}, error: {text})") + return + + logger.debug("Event dialin-ready was handled successfully") + except asyncio.TimeoutError: + logger.error(f"Timeout handling dialin-ready event ({url})") + except Exception as e: + logger.exception(f"Error handling dialin-ready event ({url}): {e}") + + async def _on_dialin_ready(self, sip_endpoint): + if self._params.dialin_settings: + await self._handle_dialin_ready(sip_endpoint) + await self._call_event_handler("on_dialin_ready", sip_endpoint) + + async def _on_dialout_answered(self, data): + await self._call_event_handler("on_dialout_answered", data) + + async def _on_dialout_connected(self, data): + await self._call_event_handler("on_dialout_connected", data) + + async def _on_dialout_stopped(self, data): + await self._call_event_handler("on_dialout_stopped", data) + + async def _on_dialout_error(self, data): + await self._call_event_handler("on_dialout_error", data) + + async def _on_dialout_warning(self, data): + await self._call_event_handler("on_dialout_warning", data) + + async def _on_participant_joined(self, participant): + await self._call_event_handler("on_participant_joined", participant) + + async def _on_participant_left(self, participant, reason): + await self._call_event_handler("on_participant_left", participant, reason) + + async def _on_first_participant_joined(self, participant): + await self._call_event_handler("on_first_participant_joined", participant) + + async def _on_transcription_message(self, participant_id, message): + text = message["text"] + timestamp = message["timestamp"] + is_final = message["rawResponse"]["is_final"] + if is_final: + frame = TranscriptionFrame(text, participant_id, timestamp) + logger.debug(f"Transcription (from: {participant_id}): [{text}]") + else: + frame = InterimTranscriptionFrame(text, participant_id, timestamp) + + if self._input: + await self._input.push_transcription_frame(frame) diff --git a/pipecat/transports/services/helpers/__init__.py b/pipecat/transports/services/helpers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pipecat/transports/services/helpers/daily_rest.py b/pipecat/transports/services/helpers/daily_rest.py new file mode 100644 index 0000000000000000000000000000000000000000..a70b96380a2ba80d796d0fe8330fd9b48d5bc6b2 --- /dev/null +++ b/pipecat/transports/services/helpers/daily_rest.py @@ -0,0 +1,139 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +""" 
+Daily REST Helpers + +Methods that wrap the Daily API to create rooms, check room URLs, and get meeting tokens. + +""" + +import requests +import time + +from urllib.parse import urlparse + +from pydantic import Field, BaseModel, ValidationError +from typing import Literal, Optional + + +class DailyRoomSipParams(BaseModel): + display_name: str = "sw-sip-dialin" + video: bool = False + sip_mode: str = "dial-in" + num_endpoints: int = 1 + + +class DailyRoomProperties(BaseModel, extra="allow"): + exp: float = Field(default_factory=lambda: time.time() + 5 * 60) + enable_chat: bool = False + enable_emoji_reactions: bool = False + eject_at_room_exp: bool = True + enable_dialout: Optional[bool] = None + sip: Optional[DailyRoomSipParams] = None + sip_uri: Optional[dict] = None + + @property + def sip_endpoint(self) -> str: + if not self.sip_uri: + return "" + else: + return "sip:%s" % self.sip_uri['endpoint'] + + +class DailyRoomParams(BaseModel): + name: Optional[str] = None + privacy: Literal['private', 'public'] = "public" + properties: DailyRoomProperties = DailyRoomProperties() + + +class DailyRoomObject(BaseModel): + id: str + name: str + api_created: bool + privacy: str + url: str + created_at: str + config: DailyRoomProperties + + +class DailyRESTHelper: + def __init__(self, daily_api_key: str, daily_api_url: str = "https://api.daily.co/v1"): + self.daily_api_key = daily_api_key + self.daily_api_url = daily_api_url + + def _get_name_from_url(self, room_url: str) -> str: + return urlparse(room_url).path[1:] + + def create_room(self, params: DailyRoomParams) -> DailyRoomObject: + res = requests.post( + f"{self.daily_api_url}/rooms", + headers={"Authorization": f"Bearer {self.daily_api_key}"}, + json={**params.model_dump(exclude_none=True)} + ) + + if res.status_code != 200: + raise Exception(f"Unable to create room: {res.text}") + + data = res.json() + + try: + room = DailyRoomObject(**data) + except ValidationError as e: + raise Exception(f"Invalid response: {e}") + + return room + + def _get_room_from_name(self, room_name: str) -> DailyRoomObject: + res: requests.Response = requests.get( + f"{self.daily_api_url}/rooms/{room_name}", + headers={"Authorization": f"Bearer {self.daily_api_key}"} + ) + + if res.status_code != 200: + raise Exception(f"Room not found: {room_name}") + + data = res.json() + + try: + room = DailyRoomObject(**data) + except ValidationError as e: + raise Exception(f"Invalid response: {e}") + + return room + + def get_room_from_url(self, room_url: str,) -> DailyRoomObject: + room_name = self._get_name_from_url(room_url) + return self._get_room_from_name(room_name) + + def get_token(self, room_url: str, expiry_time: float = 60 * 60, owner: bool = True) -> str: + if not room_url: + raise Exception( + "No Daily room specified. 
You must specify a Daily room in order a token to be generated.") + + expiration: float = time.time() + expiry_time + + room_name = self._get_name_from_url(room_url) + + res: requests.Response = requests.post( + f"{self.daily_api_url}/meeting-tokens", + headers={ + "Authorization": f"Bearer {self.daily_api_key}"}, + json={ + "properties": { + "room_name": room_name, + "is_owner": owner, + "exp": expiration + }}, + ) + + if res.status_code != 200: + raise Exception( + f"Failed to create meeting token: {res.status_code} {res.text}") + + token: str = res.json()["token"] + + return token diff --git a/pipecat/utils/__init__.py b/pipecat/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pipecat/utils/audio.py b/pipecat/utils/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..f103b55b9c45b06fe38cb19ce160bbb653de944f --- /dev/null +++ b/pipecat/utils/audio.py @@ -0,0 +1,54 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import audioop +import numpy as np +import pyloudnorm as pyln + + +def normalize_value(value, min_value, max_value): + normalized = (value - min_value) / (max_value - min_value) + normalized_clamped = max(0, min(1, normalized)) + return normalized_clamped + + +def calculate_audio_volume(audio: bytes, sample_rate: int) -> float: + audio_np = np.frombuffer(audio, dtype=np.int16) + audio_float = audio_np.astype(np.float64) + + block_size = audio_np.size / sample_rate + meter = pyln.Meter(sample_rate, block_size=block_size) + loudness = meter.integrated_loudness(audio_float) + + # Loudness goes from -20 to 80 (more or less), where -20 is quiet and 80 is + # loud. + loudness = normalize_value(loudness, -20, 80) + + return loudness + + +def exp_smoothing(value: float, prev_value: float, factor: float) -> float: + return prev_value + factor * (value - prev_value) + + +def ulaw_8000_to_pcm_16000(ulaw_8000_bytes): + # Convert μ-law to PCM + pcm_8000_bytes = audioop.ulaw2lin(ulaw_8000_bytes, 2) + + # Resample from 8000 Hz to 16000 Hz + pcm_16000_bytes = audioop.ratecv(pcm_8000_bytes, 2, 1, 8000, 16000, None)[0] + + return pcm_16000_bytes + + +def pcm_16000_to_ulaw_8000(pcm_16000_bytes): + # Resample from 16000 Hz to 8000 Hz + pcm_8000_bytes = audioop.ratecv(pcm_16000_bytes, 2, 1, 16000, 8000, None)[0] + + # Convert PCM to μ-law + ulaw_8000_bytes = audioop.lin2ulaw(pcm_8000_bytes, 2) + + return ulaw_8000_bytes diff --git a/pipecat/utils/test_frame_processor.py b/pipecat/utils/test_frame_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..cc4a35b181ad0792b0c389de25381af786b0812f --- /dev/null +++ b/pipecat/utils/test_frame_processor.py @@ -0,0 +1,43 @@ +from typing import List +from pipecat.processors.frame_processor import FrameProcessor + + +class TestException(Exception): + pass + + +class TestFrameProcessor(FrameProcessor): + def __init__(self, test_frames): + self.test_frames = test_frames + self._list_counter = 0 + super().__init__() + + async def process_frame(self, frame, direction): + await super().process_frame(frame, direction) + + if not self.test_frames[0]: # then we've run out of required frames but the generator is still going? 
+            raise TestException(f"Oops, got an extra frame, {frame}")
+        if isinstance(self.test_frames[0], List):
+            # Consume frames until we see the frame type expected after the list.
+            next_frame = self.test_frames[1]
+            if isinstance(frame, next_frame):
+                # We are done iterating the list.
+                print(f"TestFrameProcessor got expected list exit frame: {frame}")
+                # Pop twice to remove both the list and the frame expected after it.
+                self.test_frames.pop(0)
+                self.test_frames.pop(0)
+                self._list_counter = 0
+            else:
+                fl = self.test_frames[0]
+                fl_el = fl[self._list_counter % len(fl)]
+                if isinstance(frame, fl_el):
+                    print(f"TestFrameProcessor got expected list frame: {frame}")
+                    self._list_counter += 1
+                else:
+                    raise TestException(f"Inside a list, expected {fl_el} but got {frame}")
+
+        else:
+            if not isinstance(frame, self.test_frames[0]):
+                raise TestException(f"Expected {self.test_frames[0]}, but got {frame}")
+            print(f"TestFrameProcessor got expected frame: {frame}")
+            self.test_frames.pop(0)
diff --git a/pipecat/utils/utils.py b/pipecat/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..190e20f1999079379099f02b108b1ef6a724ad8f
--- /dev/null
+++ b/pipecat/utils/utils.py
@@ -0,0 +1,35 @@
+#
+# Copyright (c) 2024, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+from threading import Lock
+
+_COUNTS = {}
+_COUNTS_MUTEX = Lock()
+
+_ID = 0
+_ID_MUTEX = Lock()
+
+
+def obj_id() -> int:
+    global _ID, _ID_MUTEX
+    with _ID_MUTEX:
+        _ID += 1
+        return _ID
+
+
+def obj_count(obj) -> int:
+    global _COUNTS, _COUNTS_MUTEX
+    name = obj.__class__.__name__
+    with _COUNTS_MUTEX:
+        if name not in _COUNTS:
+            _COUNTS[name] = 0
+        else:
+            _COUNTS[name] += 1
+    return _COUNTS[name]
+
+
+def exp_smoothing(value: float, prev_value: float, factor: float) -> float:
+    return prev_value + factor * (value - prev_value)
diff --git a/pipecat/vad/__init__.py b/pipecat/vad/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/pipecat/vad/silero.py b/pipecat/vad/silero.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbe58382eef0fc0fb0a117037419f6722f3684b5
--- /dev/null
+++ b/pipecat/vad/silero.py
@@ -0,0 +1,132 @@
+#
+# Copyright (c) 2024, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import time
+
+import numpy as np
+
+from pipecat.frames.frames import AudioRawFrame, Frame, UserStartedSpeakingFrame, UserStoppedSpeakingFrame
+from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
+from pipecat.vad.vad_analyzer import VADAnalyzer, VADParams, VADState
+
+from loguru import logger
+
+try:
+    import torch
+    # We don't use torchaudio here, but we need to try importing it because
+    # Silero uses it.
+ import torchaudio + + torch.set_num_threads(1) + +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error("In order to use Silero VAD, you need to `pip install pipecat-ai[silero]`.") + raise Exception(f"Missing module(s): {e}") + +# How often should we reset internal model state +_MODEL_RESET_STATES_TIME = 5.0 + + +class SileroVADAnalyzer(VADAnalyzer): + + def __init__( + self, + *, + sample_rate: int = 16000, + version: str = "v5.0", + params: VADParams = VADParams()): + super().__init__(sample_rate=sample_rate, num_channels=1, params=params) + + if sample_rate != 16000 and sample_rate != 8000: + raise ValueError("Silero VAD sample rate needs to be 16000 or 8000") + + logger.debug("Loading Silero VAD model...") + + (self._model, _) = torch.hub.load(repo_or_dir=f"snakers4/silero-vad:{version}", + model="silero_vad", + force_reload=False, + trust_repo=True) + + self._last_reset_time = 0 + + logger.debug("Loaded Silero VAD") + + # + # VADAnalyzer + # + + def num_frames_required(self) -> int: + return 512 if self.sample_rate == 16000 else 256 + + def voice_confidence(self, buffer) -> float: + try: + audio_int16 = np.frombuffer(buffer, np.int16) + # Divide by 32768 because we have signed 16-bit data. + audio_float32 = np.frombuffer(audio_int16, dtype=np.int16).astype(np.float32) / 32768.0 + new_confidence = self._model(torch.from_numpy(audio_float32), self.sample_rate).item() + + # We need to reset the model from time to time because it doesn't + # really need all the data and memory will keep growing otherwise. + curr_time = time.time() + diff_time = curr_time - self._last_reset_time + if diff_time >= _MODEL_RESET_STATES_TIME: + self._model.reset_states() + self._last_reset_time = curr_time + + return new_confidence + except Exception as e: + # This comes from an empty audio array + logger.exception(f"Error analyzing audio with Silero VAD: {e}") + return 0 + + +class SileroVAD(FrameProcessor): + + def __init__( + self, + *, + sample_rate: int = 16000, + version: str = "v5.0", + vad_params: VADParams = VADParams(), + audio_passthrough: bool = False): + super().__init__() + + self._vad_analyzer = SileroVADAnalyzer( + sample_rate=sample_rate, version=version, params=vad_params) + self._audio_passthrough = audio_passthrough + + self._processor_vad_state: VADState = VADState.QUIET + + # + # FrameProcessor + # + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if isinstance(frame, AudioRawFrame): + await self._analyze_audio(frame) + if self._audio_passthrough: + await self.push_frame(frame, direction) + else: + await self.push_frame(frame, direction) + + async def _analyze_audio(self, frame: AudioRawFrame): + # Check VAD and push event if necessary. We just care about changes + # from QUIET to SPEAKING and vice versa. 
+ new_vad_state = self._vad_analyzer.analyze_audio(frame.audio) + if new_vad_state != self._processor_vad_state and new_vad_state != VADState.STARTING and new_vad_state != VADState.STOPPING: + new_frame = None + + if new_vad_state == VADState.SPEAKING: + new_frame = UserStartedSpeakingFrame() + elif new_vad_state == VADState.QUIET: + new_frame = UserStoppedSpeakingFrame() + + if new_frame: + await self.push_frame(new_frame) + self._processor_vad_state = new_vad_state diff --git a/pipecat/vad/vad_analyzer.py b/pipecat/vad/vad_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..dd12f52a13b3287d2407ee625164b0548db5d9cb --- /dev/null +++ b/pipecat/vad/vad_analyzer.py @@ -0,0 +1,120 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from abc import abstractmethod +from enum import Enum + +from pydantic.main import BaseModel + +from pipecat.utils.audio import calculate_audio_volume, exp_smoothing + + +class VADState(Enum): + QUIET = 1 + STARTING = 2 + SPEAKING = 3 + STOPPING = 4 + + +class VADParams(BaseModel): + confidence: float = 0.7 + start_secs: float = 0.2 + stop_secs: float = 0.8 + min_volume: float = 0.6 + + +class VADAnalyzer: + + def __init__(self, *, sample_rate: int, num_channels: int, params: VADParams): + self._sample_rate = sample_rate + self._num_channels = num_channels + self._params = params + self._vad_frames = self.num_frames_required() + self._vad_frames_num_bytes = self._vad_frames * num_channels * 2 + + vad_frames_per_sec = self._vad_frames / self._sample_rate + + self._vad_start_frames = round(self._params.start_secs / vad_frames_per_sec) + self._vad_stop_frames = round(self._params.stop_secs / vad_frames_per_sec) + self._vad_starting_count = 0 + self._vad_stopping_count = 0 + self._vad_state: VADState = VADState.QUIET + + self._vad_buffer = b"" + + # Volume exponential smoothing + self._smoothing_factor = 0.2 + self._prev_volume = 0 + + @property + def sample_rate(self): + return self._sample_rate + + @abstractmethod + def num_frames_required(self) -> int: + pass + + @abstractmethod + def voice_confidence(self, buffer) -> float: + pass + + def _get_smoothed_volume(self, audio: bytes) -> float: + volume = calculate_audio_volume(audio, self._sample_rate) + return exp_smoothing(volume, self._prev_volume, self._smoothing_factor) + + def analyze_audio(self, buffer) -> VADState: + self._vad_buffer += buffer + + num_required_bytes = self._vad_frames_num_bytes + if len(self._vad_buffer) < num_required_bytes: + return self._vad_state + + audio_frames = self._vad_buffer[:num_required_bytes] + self._vad_buffer = self._vad_buffer[num_required_bytes:] + + confidence = self.voice_confidence(audio_frames) + + volume = self._get_smoothed_volume(audio_frames) + self._prev_volume = volume + + speaking = confidence >= self._params.confidence and volume >= self._params.min_volume + + if speaking: + match self._vad_state: + case VADState.QUIET: + self._vad_state = VADState.STARTING + self._vad_starting_count = 1 + case VADState.STARTING: + self._vad_starting_count += 1 + case VADState.STOPPING: + self._vad_state = VADState.SPEAKING + self._vad_stopping_count = 0 + else: + match self._vad_state: + case VADState.STARTING: + self._vad_state = VADState.QUIET + self._vad_starting_count = 0 + case VADState.SPEAKING: + self._vad_state = VADState.STOPPING + self._vad_stopping_count = 1 + case VADState.STOPPING: + self._vad_stopping_count += 1 + + if ( + self._vad_state == VADState.STARTING + and self._vad_starting_count >= 
self._vad_start_frames
+        ):
+            self._vad_state = VADState.SPEAKING
+            self._vad_starting_count = 0
+
+        if (
+            self._vad_state == VADState.STOPPING
+            and self._vad_stopping_count >= self._vad_stop_frames
+        ):
+            self._vad_state = VADState.QUIET
+            self._vad_stopping_count = 0
+
+        return self._vad_state
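The chunking above is what turns `start_secs` and `stop_secs` into frame counts: with `SileroVADAnalyzer` at 16 kHz, `num_frames_required()` is 512 samples (32 ms per chunk), so the defaults `start_secs=0.2` and `stop_secs=0.8` become `round(0.2 / 0.032) = 6` and `round(0.8 / 0.032) = 25` consecutive chunks before the state machine commits to SPEAKING or QUIET. A rough driving loop, assuming pipecat is installed with the Silero extra and the torch hub model download succeeds:

```python
import numpy as np

from pipecat.vad.silero import SileroVADAnalyzer
from pipecat.vad.vad_analyzer import VADParams, VADState

analyzer = SileroVADAnalyzer(
    sample_rate=16000,
    params=VADParams(start_secs=0.2, stop_secs=0.8))  # the defaults, spelled out

# One analysis chunk: 512 samples of 16-bit mono PCM (1024 bytes, 32 ms).
chunk = np.zeros(analyzer.num_frames_required(), dtype=np.int16).tobytes()

state = VADState.QUIET
for _ in range(50):
    state = analyzer.analyze_audio(chunk)

# Silence never crosses the confidence/volume thresholds, so we stay QUIET.
assert state == VADState.QUIET
```

Speech shorter than `start_secs` therefore never leaves STARTING, and pauses shorter than `stop_secs` never leave STOPPING, which is exactly the debouncing the state machine above encodes.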