restaurants / azure_openai.py
briankchan's picture
Add app
e6e69dc
import os
from typing import Self
from chainlit import LLMSettings
from chainlit.telemetry import trace_event
from chainlit.types import CompletionRequest
from pydantic.dataclasses import dataclass
from starlette.responses import PlainTextResponse
@dataclass
class AzureOpenaiSettings(LLMSettings):
    """Chainlit ``LLMSettings`` extended with Azure OpenAI connection fields.

    ``to_settings_dict`` merges the four Azure fields below into whatever the
    base class's ``to_settings_dict`` returns.
    """

    api_type: str = 'azure'
    # Populated from AZURE_OPENAI_ENDPOINT by load_from_env.
    api_base: str = ''
    # Populated from AZURE_OPENAI_DEPLOYMENT by load_from_env.
    engine: str = ''
    api_version: str = '2023-05-15'

    def to_settings_dict(self) -> dict:
        """Return the base LLM settings plus the Azure connection fields.

        The Azure keys are added after the base entries, so they win on any
        key collision.
        """
        return {
            **super().to_settings_dict(),
            "api_type": self.api_type,
            "api_base": self.api_base,
            "api_version": self.api_version,
            "engine": self.engine,
        }

    @classmethod
    def load_from_env(cls: type[Self], *args, **kwargs) -> Self:
        """Build settings from the ``AZURE_OPENAI_*`` environment variables.

        Reads ``AZURE_OPENAI_ENDPOINT`` (``api_base``),
        ``AZURE_OPENAI_DEPLOYMENT`` (``engine``) and ``AZURE_OPENAI_VERSION``
        (``api_version``); ``api_type`` is always ``'azure'`` unless
        overridden.

        Args:
            *args: Forwarded positionally to the constructor.
            **kwargs: Forwarded to the constructor; any key here overrides the
                value derived from the environment.

        Returns:
            A new instance of ``cls``.
        """
        # Unset variables fall back to the field defaults: passing None into
        # these str-typed pydantic fields would fail validation with an error
        # that doesn't mention which environment variable is missing.
        from_env = {
            'api_type': 'azure',
            'api_base': os.environ.get('AZURE_OPENAI_ENDPOINT', ''),
            'engine': os.environ.get('AZURE_OPENAI_DEPLOYMENT', ''),
            'api_version': os.environ.get('AZURE_OPENAI_VERSION', '2023-05-15'),
        }
        # Explicit keyword arguments take precedence over the environment.
        return cls(*args, **{**from_env, **kwargs})
@dataclass
class AzureOpenaiEmbeddings:
api_type: str = 'azure'
api_base: str = ''
engine: str = ''
api_version: str = '2023-05-15'
def to_settings_dict(self):
return {
"api_type": self.api_type,
"api_base": self.api_base,
"api_version": self.api_version,
"engine": self.engine,
}
@classmethod
def load_from_env(cls: type[Self], *args, **kwargs) -> Self:
return cls(
*args,
api_type='azure',
api_base=os.environ.get('AZURE_OPENAI_ENDPOINT'),
engine=os.environ.get('AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT'),
api_version=os.environ.get('AZURE_OPENAI_VERSION', '2023-05-15'),
**kwargs,
)
def patch_chainlit():
    """Replace Chainlit's playground ``/completion`` route with an Azure one.

    Every route currently registered at ``/completion`` on Chainlit's server
    app is removed. A new POST handler is then registered that sends the
    prompt to ``openai.ChatCompletion.acreate`` with Azure connection details
    read from the ``AZURE_OPENAI_*`` environment variables.
    """
    from chainlit.server import app

    # Remove every existing /completion route before registering the replacement.
    app.router.routes = [
        existing
        for existing in app.router.routes
        if existing.path != '/completion'
    ]

    @app.post("/completion")
    async def completion(request: CompletionRequest):
        """Handle a completion request from the prompt playground."""
        import openai

        trace_event("completion")

        # A key in the request's userEnv takes precedence over the server's
        # OPENAI_API_KEY environment variable.
        server_key = os.environ.get("OPENAI_API_KEY")
        api_key = request.userEnv.get("OPENAI_API_KEY", server_key)

        stop = request.settings.stop
        if isinstance(stop, list) and not stop:
            # OpenAI doesn't accept an empty stop list, so send None instead.
            stop = None

        # HACK: the connection settings come straight from the environment;
        # of request.settings, only `stop` is forwarded to the API.
        azure_settings = {
            'api_type': 'azure',
            'api_base': os.environ.get('AZURE_OPENAI_ENDPOINT'),
            'engine': os.environ.get('AZURE_OPENAI_DEPLOYMENT'),
            'api_version': os.environ.get('AZURE_OPENAI_VERSION', '2023-05-15'),
        }
        chat_messages = [{"role": "user", "content": request.prompt}]

        response = await openai.ChatCompletion.acreate(
            api_key=api_key,
            messages=chat_messages,
            stop=stop,
            **azure_settings,
        )
        first_choice = response["choices"][0]
        return PlainTextResponse(content=first_choice["message"]["content"])