|
|
|
|
|
|
|
|
|
|
|
|
|
import asyncio
|
|
|
|
from itertools import chain
|
|
from typing import List
|
|
|
|
from pipecat.pipeline.base_pipeline import BasePipeline
|
|
from pipecat.pipeline.pipeline import Pipeline
|
|
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
|
from pipecat.frames.frames import Frame
|
|
|
|
from loguru import logger
|
|
|
|
|
|
class Source(FrameProcessor):
    """Entry-point processor placed at the head of a parallel sub-pipeline.

    Frames traveling downstream are simply forwarded into the sub-pipeline.
    Frames traveling upstream are handed off to the shared upstream queue so
    the parent task can collect and merge them.
    """

    def __init__(self, upstream_queue: asyncio.Queue):
        super().__init__()
        self._up_queue = upstream_queue

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        await super().process_frame(frame, direction)

        if direction == FrameDirection.UPSTREAM:
            # Hand the frame back to the parent via the shared queue.
            await self._up_queue.put(frame)
        elif direction == FrameDirection.DOWNSTREAM:
            # Continue into the sub-pipeline.
            await self.push_frame(frame, direction)
|
|
|
|
|
|
class Sink(FrameProcessor):
    """Exit-point processor placed at the tail of a parallel sub-pipeline.

    Frames traveling upstream are pushed back into the sub-pipeline. Frames
    traveling downstream are handed off to the shared downstream queue so
    the parent task can collect and merge them.
    """

    def __init__(self, downstream_queue: asyncio.Queue):
        super().__init__()
        self._down_queue = downstream_queue

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        await super().process_frame(frame, direction)

        if direction == FrameDirection.UPSTREAM:
            # Continue back up the sub-pipeline.
            await self.push_frame(frame, direction)
        elif direction == FrameDirection.DOWNSTREAM:
            # Hand the frame back to the parent via the shared queue.
            await self._down_queue.put(frame)
|
|
|
|
|
|
class ParallelTask(BasePipeline):
    """Runs several lists of processors as parallel sub-pipelines.

    Each positional argument must be a list of processors. Every list is
    wrapped between a ``Source`` and a ``Sink`` and turned into its own
    ``Pipeline``. Frames processed by this task are fanned out to every
    sub-pipeline; frames that the sub-pipelines emit are collected from two
    shared queues and pushed onward, de-duplicated by frame id so a frame
    produced by several branches is only forwarded once.

    Raises:
        ValueError: if no processor lists are given.
        TypeError: if any argument is not a list.
    """

    def __init__(self, *args):
        super().__init__()

        if len(args) == 0:
            # ValueError is more precise than a bare Exception and stays
            # backward-compatible with callers catching Exception.
            raise ValueError("ParallelTask needs at least one argument")

        self._sinks = []
        self._pipelines = []

        # Shared by all sub-pipelines: every Source feeds the upstream queue
        # and every Sink feeds the downstream queue.
        self._up_queue = asyncio.Queue()
        self._down_queue = asyncio.Queue()

        logger.debug(f"Creating {self} pipelines")
        for processors in args:
            if not isinstance(processors, list):
                raise TypeError(f"ParallelTask argument {processors} is not a list")

            # Wrap the user's processors between a Source and a Sink so this
            # task can intercept frames entering and leaving each branch.
            source = Source(self._up_queue)
            sink = Sink(self._down_queue)
            processors: List[FrameProcessor] = [source] + processors
            processors.append(sink)

            # Keep the sink so upstream frames can be injected at the tail.
            self._sinks.append(sink)

            pipeline = Pipeline(processors)
            self._pipelines.append(pipeline)
        logger.debug(f"Finished creating {self} pipelines")

    #
    # BasePipeline
    #

    def processors_with_metrics(self) -> List[FrameProcessor]:
        """Return the metrics-capable processors of every sub-pipeline, flattened."""
        return list(chain.from_iterable(p.processors_with_metrics() for p in self._pipelines))

    #
    # Frame processor
    #

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Fan the frame out to all branches, then forward what they produced."""
        await super().process_frame(frame, direction)

        if direction == FrameDirection.UPSTREAM:
            # Upstream frames enter each branch through its Sink (the tail).
            await asyncio.gather(*[s.process_frame(frame, direction) for s in self._sinks])
        elif direction == FrameDirection.DOWNSTREAM:
            # Downstream frames enter each branch through its head.
            await asyncio.gather(*[p.process_frame(frame, direction) for p in self._pipelines])

        # Forward whatever the branches emitted, de-duplicating by frame id
        # (the same frame may come out of several branches).
        await self._drain_queue(self._up_queue, FrameDirection.UPSTREAM)
        await self._drain_queue(self._down_queue, FrameDirection.DOWNSTREAM)

    async def _drain_queue(self, queue: asyncio.Queue, direction: FrameDirection):
        """Push every frame currently queued, forwarding each frame id only once."""
        seen_ids = set()
        while not queue.empty():
            frame = await queue.get()
            if frame and frame.id not in seen_ids:
                await self.push_frame(frame, direction)
                seen_ids.add(frame.id)
            queue.task_done()
|
|
|