import aiohttp
import base64
import io
import json

from typing import AsyncGenerator, List, Literal

from loguru import logger
from PIL import Image

from pipecat.frames.frames import (
    AudioRawFrame,
    ErrorFrame,
    Frame,
    LLMFullResponseEndFrame,
    LLMFullResponseStartFrame,
    LLMMessagesFrame,
    LLMResponseEndFrame,
    LLMResponseStartFrame,
    TextFrame,
    URLImageRawFrame,
    VisionImageRawFrame
)
from pipecat.processors.aggregators.openai_llm_context import (
    OpenAILLMContext,
    OpenAILLMContextFrame
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import (
    ImageGenService,
    LLMService,
    TTSService
)

try:
    from openai import AsyncOpenAI, AsyncStream, BadRequestError
    from openai.types.chat import (
        ChatCompletionChunk,
        ChatCompletionFunctionMessageParam,
        ChatCompletionMessageParam,
        ChatCompletionToolParam
    )
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error(
        "In order to use OpenAI, you need to `pip install pipecat-ai[openai]`. Also, set the `OPENAI_API_KEY` environment variable.")
    raise Exception(f"Missing module: {e}")


class OpenAIUnhandledFunctionException(Exception):
    pass


class BaseOpenAILLMService(LLMService):
    """This is the base for all services that use the AsyncOpenAI client.

    This service consumes OpenAILLMContextFrame frames, which contain a reference
    to an OpenAILLMContext object. The OpenAILLMContext defines the context sent
    to the LLM for a completion: user, assistant, and system messages, as well as
    tool choices and tool definitions, which are used when requesting function
    calls from the LLM.
    """
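
    # A minimal usage sketch, hedged: subclasses such as OpenAILLMService below
    # are the usual entry point, and the wiring shown here is an illustrative
    # assumption, not something this module does itself.
    #
    #   llm = OpenAILLMService(model="gpt-4o", api_key=os.getenv("OPENAI_API_KEY"))
    #   context = OpenAILLMContext(messages=[{"role": "user", "content": "Hi!"}])
    #   # Pushing OpenAILLMContextFrame(context) through a pipeline makes this
    #   # service emit TextFrame chunks between LLMFullResponseStart/EndFrame.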

    def __init__(self, *, model: str, api_key=None, base_url=None, **kwargs):
        super().__init__(**kwargs)
        self._model: str = model
        self._client = self.create_client(api_key=api_key, base_url=base_url, **kwargs)

    def create_client(self, api_key=None, base_url=None, **kwargs):
        return AsyncOpenAI(api_key=api_key, base_url=base_url)

    def can_generate_metrics(self) -> bool:
        return True

    async def get_chat_completions(
            self,
            context: OpenAILLMContext,
            messages: List[ChatCompletionMessageParam]) -> AsyncStream[ChatCompletionChunk]:
        chunks = await self._client.chat.completions.create(
            model=self._model,
            stream=True,
            messages=messages,
            tools=context.tools,
            tool_choice=context.tool_choice,
        )
        return chunks

    async def _stream_chat_completions(
            self, context: OpenAILLMContext) -> AsyncStream[ChatCompletionChunk]:
        logger.debug(f"Generating chat: {context.get_messages_json()}")

        messages: List[ChatCompletionMessageParam] = context.get_messages()

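        # Convert image messages, stored in the context as raw JPEG bytes plus a
        # mime_type, into the OpenAI vision format: a content list that carries
        # the original text and a base64-encoded data URL for the image.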
        for message in messages:
            if message.get("mime_type") == "image/jpeg":
                encoded_image = base64.b64encode(message["data"].getvalue()).decode("utf-8")
                text = message["content"]
                message["content"] = [
                    {"type": "text", "text": text},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}
                ]
                del message["data"]
                del message["mime_type"]

        chunks = await self.get_chat_completions(context, messages)

        return chunks

    async def _process_context(self, context: OpenAILLMContext):
        function_name = ""
        arguments = ""
        tool_call_id = ""

        await self.start_ttfb_metrics()

        chunk_stream: AsyncStream[ChatCompletionChunk] = (
            await self._stream_chat_completions(context)
        )

        async for chunk in chunk_stream:
            if len(chunk.choices) == 0:
                continue

            await self.stop_ttfb_metrics()

            if chunk.choices[0].delta.tool_calls:
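                # Tool-call deltas arrive incrementally: the first chunk carries
                # the function name and tool call id, and later chunks stream
                # JSON fragments of the arguments, which are buffered until the
                # stream ends and the call can be dispatched.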
                tool_call = chunk.choices[0].delta.tool_calls[0]
                if tool_call.function and tool_call.function.name:
                    function_name += tool_call.function.name
                    tool_call_id = tool_call.id
                    await self.call_start_function(function_name)
                if tool_call.function and tool_call.function.arguments:
                    arguments += tool_call.function.arguments
            elif chunk.choices[0].delta.content:
                await self.push_frame(LLMResponseStartFrame())
                await self.push_frame(TextFrame(chunk.choices[0].delta.content))
                await self.push_frame(LLMResponseEndFrame())

        # The stream is complete; if the LLM asked for a function call, run the
        # registered callback now.
        if function_name and arguments:
            if self.has_function(function_name):
                await self._handle_function_call(context, tool_call_id, function_name, arguments)
            else:
                raise OpenAIUnhandledFunctionException(
                    f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function.")
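
    # Hedged sketch of registering a callback for the function-call path above.
    # register_function() comes from the pipecat LLMService base class; the
    # callback name and body here are illustrative assumptions:
    #
    #   async def get_weather(llm, args):
    #       return {"conditions": "sunny"}
    #
    #   llm.register_function("get_weather", get_weather)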

    async def _handle_function_call(
            self,
            context,
            tool_call_id,
            function_name,
            arguments
    ):
        arguments = json.loads(arguments)
        result = await self.call_function(function_name, arguments)
        arguments = json.dumps(arguments)
        if isinstance(result, (str, dict)):
            # Record the assistant's tool call in the context, append the tool
            # result, then run another completion so the LLM can use it.
            tool_call = ChatCompletionFunctionMessageParam({
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": tool_call_id,
                        "function": {
                            "arguments": arguments,
                            "name": function_name
                        },
                        "type": "function"
                    }
                ]
            })
            context.add_message(tool_call)
            if isinstance(result, dict):
                result = json.dumps(result)
            tool_result = ChatCompletionToolParam({
                "tool_call_id": tool_call_id,
                "role": "tool",
                "content": result
            })
            context.add_message(tool_result)

            await self._process_context(context)
        elif isinstance(result, list):
            # The callback returned ready-made messages; add them all, then
            # run another completion.
            for msg in result:
                context.add_message(msg)
            await self._process_context(context)
        elif result is None:
            pass
        else:
            raise TypeError(f"Unknown return type from function callback: {type(result)}")

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        await super().process_frame(frame, direction)

        context = None
        if isinstance(frame, OpenAILLMContextFrame):
            context: OpenAILLMContext = frame.context
        elif isinstance(frame, LLMMessagesFrame):
            context = OpenAILLMContext.from_messages(frame.messages)
        elif isinstance(frame, VisionImageRawFrame):
            context = OpenAILLMContext.from_image_frame(frame)
        else:
            # Not a frame this service understands; pass it through unchanged.
            await self.push_frame(frame, direction)

        if context:
            await self.push_frame(LLMFullResponseStartFrame())
            await self.start_processing_metrics()
            await self._process_context(context)
            await self.stop_processing_metrics()
            await self.push_frame(LLMFullResponseEndFrame())


class OpenAILLMService(BaseOpenAILLMService):

    def __init__(self, *, model: str = "gpt-4o", **kwargs):
        super().__init__(model=model, **kwargs)
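
    # Hedged usage sketch (Pipeline and the surrounding wiring come from pipecat
    # and are assumptions about the host app, not used by this module):
    #
    #   llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
    #   pipeline = Pipeline([transport.input(), user_aggregator, llm, tts, transport.output()])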


class OpenAIImageGenService(ImageGenService):
    """Generates images with the OpenAI Images API and yields each result as a
    URLImageRawFrame."""

    def __init__(
            self,
            *,
            image_size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
            aiohttp_session: aiohttp.ClientSession,
            api_key: str,
            model: str = "dall-e-3",
    ):
        super().__init__()
        self._model = model
        self._image_size = image_size
        self._client = AsyncOpenAI(api_key=api_key)
        self._aiohttp_session = aiohttp_session

    async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
        logger.debug(f"Generating image from prompt: {prompt}")

        image = await self._client.images.generate(
            prompt=prompt,
            model=self._model,
            n=1,
            size=self._image_size
        )
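
        # The Images API responds with a hosted URL rather than raw bytes, so
        # the image is downloaded below using the shared aiohttp session.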
        image_url = image.data[0].url

        if not image_url:
            logger.error(f"{self} No image provided in response: {image}")
            yield ErrorFrame("Image generation failed")
            return

        async with self._aiohttp_session.get(image_url) as response:
            image_stream = io.BytesIO(await response.content.read())
            image = Image.open(image_stream)
            frame = URLImageRawFrame(image_url, image.tobytes(), image.size, image.format)
            yield frame


class OpenAITTSService(TTSService):
    """This service uses the OpenAI TTS API to generate audio from text.

    The returned audio is PCM encoded at 24kHz. When using the DailyTransport,
    set the sample rate in the DailyParams accordingly:

    ```
    DailyParams(
        audio_out_enabled=True,
        audio_out_sample_rate=24_000,
    )
    ```
    """

    def __init__(
            self,
            *,
            api_key: str | None = None,
            base_url: str | None = None,
            sample_rate: int = 24_000,
            voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = "alloy",
            model: Literal["tts-1", "tts-1-hd"] = "tts-1",
            **kwargs):
        super().__init__(**kwargs)

        self._voice = voice
        self._model = model
        self.sample_rate = sample_rate
        self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)

    def can_generate_metrics(self) -> bool:
        return True

    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        logger.debug(f"Generating TTS: [{text}]")

        try:
            await self.start_ttfb_metrics()

            async with self._client.audio.speech.with_streaming_response.create(
                input=text,
                model=self._model,
                voice=self._voice,
                response_format="pcm",
            ) as r:
                if r.status_code != 200:
                    error = await r.text()
                    logger.error(
                        f"{self} error getting audio (status: {r.status_code}, error: {error})")
                    yield ErrorFrame(f"Error getting audio (status: {r.status_code}, error: {error})")
                    return

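                # Stream PCM back as it arrives, framed into AudioRawFrame
                # chunks; TTFB metrics stop at the first non-empty chunk.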
                async for chunk in r.iter_bytes(8192):
                    if len(chunk) > 0:
                        await self.stop_ttfb_metrics()
                        frame = AudioRawFrame(chunk, self.sample_rate, 1)
                        yield frame
        except BadRequestError as e:
            logger.exception(f"{self} error generating TTS: {e}")