|
import os |
|
from typing import List |
|
|
|
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler |
|
if os.getenv('OPENAI_API_TYPE') == 'azure': |
|
from langchain.chat_models import AzureChatOpenAI |
|
else: |
|
from langchain.chat_models import ChatOpenAI |
|
from langchain.schema import BaseMessage, HumanMessage |
|
|
|
from realtime_ai_character.database.chroma import get_chroma |
|
from realtime_ai_character.llm.base import AsyncCallbackAudioHandler, AsyncCallbackTextHandler, LLM |
|
from realtime_ai_character.logger import get_logger |
|
from realtime_ai_character.utils import Character |
|
|
|
# Module-level logger named after this module, obtained from the project's
# logging helper; used by OpenaiLlm for request/response logging.
logger = get_logger(__name__)
|
|
|
|
|
class OpenaiLlm(LLM):
    """LLM backend that answers through an OpenAI (or Azure OpenAI) chat model.

    Context for each user query is retrieved from the vector store returned by
    ``get_chroma()`` and limited to documents whose ``character_name`` metadata
    matches the active character.
    """

    def __init__(self, model: str, temperature: float = 0.5):
        """Create the chat client and open the vector store.

        Args:
            model: Model name passed to the chat client as ``model``.
            temperature: Sampling temperature for the chat client. Defaults to
                0.5, the value that was previously hard-coded.

        When the ``OPENAI_API_TYPE`` environment variable is ``'azure'`` an
        ``AzureChatOpenAI`` client is built using the deployment named by
        ``OPENAI_API_MODEL_DEPLOYMENT_NAME`` (default ``'gpt-35-turbo'``);
        otherwise a ``ChatOpenAI`` client is used. Streaming is always enabled.
        """
        # Settings shared by both client flavours.
        client_kwargs = dict(model=model, temperature=temperature, streaming=True)
        # This check must mirror the module-level import switch:
        # AzureChatOpenAI / ChatOpenAI is only imported for the matching value.
        if os.getenv('OPENAI_API_TYPE') == 'azure':
            self.chat_open_ai = AzureChatOpenAI(
                deployment_name=os.getenv(
                    'OPENAI_API_MODEL_DEPLOYMENT_NAME', 'gpt-35-turbo'),
                **client_kwargs,
            )
        else:
            self.chat_open_ai = ChatOpenAI(**client_kwargs)
        self.db = get_chroma()

    async def achat(self,
                    history: List[BaseMessage],
                    user_input: str,
                    user_input_template: str,
                    callback: AsyncCallbackTextHandler,
                    audioCallback: AsyncCallbackAudioHandler,
                    character: Character) -> str:
        """Generate the model's reply to ``user_input`` for ``character``.

        The user input is combined with retrieved context via
        ``user_input_template.format(context=..., query=...)``, appended to
        ``history`` as a ``HumanMessage`` (the caller's list is mutated in
        place), and the whole history is sent to the chat model.

        Args:
            history: Prior conversation messages; the new human message is
                appended to this list.
            user_input: Raw text from the user.
            user_input_template: Format string that may use the ``{context}``
                and ``{query}`` placeholders.
            callback: Text callback handler passed to the model call.
            audioCallback: Audio callback handler passed to the model call.
            character: Character whose documents are used as context.

        Returns:
            Text of the first generation for the single prompt sent.
        """
        logger.debug('user_input=%s', user_input)
        context = self._generate_context(user_input, character)

        history.append(HumanMessage(content=user_input_template.format(
            context=context, query=user_input)))

        # A single prompt (the full history) is sent, so the reply is at
        # generations[0][0].
        response = await self.chat_open_ai.agenerate(
            [history],
            callbacks=[callback, audioCallback, StreamingStdOutCallbackHandler()])
        logger.info('Response: %s', response)
        return response.generations[0][0].text

    def _generate_context(self, query: str, character: Character) -> str:
        """Return newline-joined page contents of documents relevant to ``query``.

        Runs a similarity search over the vector store and keeps only the
        documents whose ``character_name`` metadata equals ``character.name``.
        Returns an empty string when no document survives the filter.
        """
        logger.debug('query=%s', query)
        # NOTE(review): the search ranks documents of *all* characters and the
        # filter below runs afterwards, so fewer (possibly zero) documents may
        # remain for this character. A metadata filter on the search itself
        # would avoid that; confirm the vector store tolerates characters with
        # few/no documents before switching.
        docs = self.db.similarity_search(query)
        # .get(): a document lacking 'character_name' metadata is skipped
        # instead of raising KeyError.
        docs = [d for d in docs
                if d.metadata.get('character_name') == character.name]
        logger.info('Found %d documents', len(docs))
        return '\n'.join(d.page_content for d in docs)
|
|