Spaces:
Configuration error
Configuration error
# What is this?
## Unit Tests for OpenAI Assistants API
import json | |
import os | |
import sys | |
import traceback | |
from dotenv import load_dotenv | |
load_dotenv() | |
sys.path.insert( | |
0, os.path.abspath("../..") | |
) # Adds the parent directory to the system path | |
import asyncio | |
import logging | |
import pytest | |
from openai.types.beta.assistant import Assistant | |
from typing_extensions import override | |
import litellm | |
from litellm import create_thread, get_thread | |
from litellm.llms.openai.openai import ( | |
AssistantEventHandler, | |
AsyncAssistantEventHandler, | |
AsyncCursorPage, | |
MessageData, | |
OpenAIAssistantsAPI, | |
) | |
from litellm.llms.openai.openai import OpenAIMessage as Message | |
from litellm.llms.openai.openai import SyncCursorPage, Thread | |
"""
V0 Scope:
- Add Message -> `/v1/threads/{thread_id}/messages`
- Run Thread -> `/v1/threads/{thread_id}/runs`
"""
def _add_azure_related_dynamic_params(data: dict) -> dict: | |
data["api_version"] = "2024-02-15-preview" | |
data["api_base"] = os.getenv("AZURE_ASSISTANTS_API_BASE") | |
data["api_key"] = os.getenv("AZURE_ASSISTANTS_API_KEY") | |
return data | |
async def test_get_assistants(provider, sync_mode):
    """List assistants for ``provider`` through the sync or async litellm API.

    When ``sync_mode == True`` the synchronous ``litellm.get_assistants`` is
    used and a ``SyncCursorPage`` is expected; otherwise the coroutine
    ``litellm.aget_assistants`` is awaited and an ``AsyncCursorPage`` is
    expected. For ``provider == "azure"`` the Azure Assistants credentials
    are attached to the request.

    NOTE(review): no parametrize decorators are visible on this test, so
    ``provider``/``sync_mode`` presumably come from decorators missing from
    this copy of the file — confirm.
    """
    request = {"custom_llm_provider": provider}
    if provider == "azure":
        request = _add_azure_related_dynamic_params(request)

    if sync_mode == True:
        page = litellm.get_assistants(**request)
        assert isinstance(page, SyncCursorPage)
        return

    page = await litellm.aget_assistants(**request)
    assert isinstance(page, AsyncCursorPage)
async def test_create_delete_assistants(provider, sync_mode):
    """Create a code-interpreter assistant, verify it, then delete it.

    ``sync_mode == True`` uses ``litellm.create_assistants`` /
    ``litellm.delete_assistant``; any other value awaits the async
    ``acreate_assistants`` / ``adelete_assistant`` counterparts. For
    ``provider == "azure"`` the Azure Assistants credentials are attached to
    both the create and the delete request.

    While running, the test sets ``litellm.ssl_verify = False`` and turns on
    litellm debug logging. ``ssl_verify`` is restored afterwards so the
    relaxed TLS setting does not leak into other tests in the same process.
    The created assistant is deleted even when one of the checks fails, so a
    failing run does not leave assistants behind on the provider.
    """
    instructions = (
        "You are a personal math tutor. When asked a question, "
        "write and run Python code to answer the question."
    )

    previous_ssl_verify = litellm.ssl_verify
    litellm.ssl_verify = False
    litellm._turn_on_debug()
    try:
        data = {
            "custom_llm_provider": provider,
            "model": "gpt-4.5-preview",
            "instructions": instructions,
            "name": "Math Tutor",
            "tools": [{"type": "code_interpreter"}],
        }
        if provider == "azure":
            data = _add_azure_related_dynamic_params(data)

        if sync_mode == True:
            assistant = litellm.create_assistants(**data)
        else:
            assistant = await litellm.acreate_assistants(**data)
        print("New assistants", assistant)

        # Capture the id before asserting so cleanup can still run if an
        # assertion below fails.
        assistant_id = getattr(assistant, "id", None)
        try:
            assert isinstance(assistant, Assistant)
            assert assistant.instructions == instructions
            assert assistant.id is not None
        finally:
            if assistant_id is not None:
                # delete the created assistant
                delete_data = {
                    "custom_llm_provider": provider,
                    "assistant_id": assistant_id,
                }
                if provider == "azure":
                    delete_data = _add_azure_related_dynamic_params(delete_data)
                if sync_mode == True:
                    response = litellm.delete_assistant(**delete_data)
                else:
                    response = await litellm.adelete_assistant(**delete_data)
                print("Response deleting assistant", response)

        # Reached only when the checks above passed, which guarantees a
        # non-None id and therefore that ``response`` was assigned.
        assert response.id == assistant.id
    finally:
        litellm.ssl_verify = previous_ssl_verify
async def test_create_thread_litellm(sync_mode, provider) -> Thread:
    """Create a thread seeded with one user message and return it.

    The other thread tests in this module call it directly as a helper and
    await the result.

    Args:
        sync_mode: truthy -> ``litellm.create_thread``; falsy -> the awaited
            ``litellm.acreate_thread``.
        provider: value for ``custom_llm_provider``; ``"azure"`` additionally
            attaches the Azure Assistants credentials.

    Returns:
        The created ``Thread``.
    """
    message: MessageData = {"role": "user", "content": "Hey, how's it going?"}  # type: ignore
    data = {
        "custom_llm_provider": provider,
        # create_thread / acreate_thread take the seed messages under the
        # plural ``messages`` keyword (mirrors OpenAI's threads.create).
        "messages": [message],
    }
    if provider == "azure":
        data = _add_azure_related_dynamic_params(data)

    if sync_mode:
        new_thread = create_thread(**data)
    else:
        new_thread = await litellm.acreate_thread(**data)

    assert isinstance(
        new_thread, Thread
    ), f"type of thread={type(new_thread)}. Expected Thread-type"
    return new_thread
async def test_get_thread_litellm(provider, sync_mode):
    """Create a thread, fetch it back by id, and check the round trip.

    ``sync_mode`` truthy uses ``litellm.get_thread``; falsy awaits
    ``litellm.aget_thread``. For ``provider == "azure"`` the Azure Assistants
    credentials are attached to the lookup.

    Returns:
        The ``Thread`` created for this test (the object itself, not the
        coroutine that produced it).
    """
    new_thread = test_create_thread_litellm(sync_mode, provider)
    # test_create_thread_litellm is declared ``async def``, so this is normally
    # a coroutine; the check keeps this working if it ever becomes synchronous.
    if asyncio.iscoroutine(new_thread):
        _new_thread = await new_thread
    else:
        _new_thread = new_thread

    data = {
        "custom_llm_provider": provider,
        "thread_id": _new_thread.id,
    }
    if provider == "azure":
        data = _add_azure_related_dynamic_params(data)

    if sync_mode:
        received_thread = get_thread(**data)
    else:
        received_thread = await litellm.aget_thread(**data)

    assert isinstance(
        received_thread, Thread
    ), f"type of thread={type(received_thread)}. Expected Thread-type"
    assert (
        received_thread.id == _new_thread.id
    ), f"fetched thread id={received_thread.id}, expected {_new_thread.id}"
    return _new_thread
async def test_add_message_litellm(sync_mode, provider):
    """Add a user message to a fresh thread and check the returned message.

    ``sync_mode`` truthy uses ``litellm.add_message``; falsy awaits
    ``litellm.a_add_message``. For ``provider == "azure"`` the Azure
    Assistants credentials are attached to the request.
    """
    new_thread = test_create_thread_litellm(sync_mode, provider)
    # test_create_thread_litellm is declared ``async def``, so this is normally
    # a coroutine; the check keeps this working if it ever becomes synchronous.
    if asyncio.iscoroutine(new_thread):
        _new_thread = await new_thread
    else:
        _new_thread = new_thread

    # add message to thread; role/content are spread as top-level kwargs
    message: MessageData = {"role": "user", "content": "Hey, how's it going?"}  # type: ignore
    data = {"custom_llm_provider": provider, "thread_id": _new_thread.id, **message}
    if provider == "azure":
        data = _add_azure_related_dynamic_params(data)

    if sync_mode:
        added_message = litellm.add_message(**data)
    else:
        added_message = await litellm.a_add_message(**data)

    print(f"added message: {added_message}")
    assert isinstance(added_message, Message)
    # The message must belong to the thread it was added to.
    assert added_message.thread_id == _new_thread.id
# | |
# | |
async def test_aarun_thread_litellm(sync_mode, provider, is_streaming):
    """
    - Get Assistants
    - Create thread
    - Create run w/ Assistants + Thread

    Details:
    - Uses the first assistant listed for ``provider``; the test is skipped
      when none exist.
    - Adds one user message to a freshly created thread, then runs the thread
      with ``run_thread_stream``/``run_thread`` (``sync_mode`` truthy) or the
      async ``arun_thread_stream``/``arun_thread``.
    - Streaming runs are driven to completion via ``until_done()``.
      Non-streaming runs must end with status ``"completed"``, after which the
      first message on the thread must be a ``Message``.
    - For ``provider == "azure"`` every request, including the run and the
      message listing, carries the Azure Assistants credentials.
    - An ``openai.APIError`` from the provider skips the test with the error
      text instead of letting it pass silently.
    """
    import openai

    try:
        get_assistants_data = {"custom_llm_provider": provider}
        if provider == "azure":
            get_assistants_data = _add_azure_related_dynamic_params(
                get_assistants_data
            )
        if sync_mode:
            assistants = litellm.get_assistants(**get_assistants_data)
        else:
            assistants = await litellm.aget_assistants(**get_assistants_data)

        ## get the first assistant ###
        try:
            assistant_id = assistants.data[0].id
        except IndexError:
            pytest.skip("No assistants found")

        new_thread = test_create_thread_litellm(sync_mode=sync_mode, provider=provider)
        # test_create_thread_litellm is ``async def``, so this is normally a
        # coroutine; the check keeps working if it ever becomes synchronous.
        if asyncio.iscoroutine(new_thread):
            _new_thread = await new_thread
        else:
            _new_thread = new_thread
        thread_id = _new_thread.id

        # Parameters for the run and message-listing calls: provider, thread
        # and (for Azure) credentials, without the message payload.
        thread_params = {"custom_llm_provider": provider, "thread_id": thread_id}
        if provider == "azure":
            thread_params = _add_azure_related_dynamic_params(thread_params)

        # add message to thread; role/content are spread as top-level kwargs
        message: MessageData = {"role": "user", "content": "Hey, how's it going?"}  # type: ignore
        data = {"custom_llm_provider": provider, "thread_id": thread_id, **message}
        if provider == "azure":
            data = _add_azure_related_dynamic_params(data)

        if sync_mode:
            added_message = litellm.add_message(**data)
            assert isinstance(added_message, Message)
            if is_streaming:
                with litellm.run_thread_stream(
                    assistant_id=assistant_id, **thread_params
                ) as event_handler:
                    assert isinstance(event_handler, AssistantEventHandler)
                    print(event_handler)
                    event_handler.until_done()
            else:
                run = litellm.run_thread(
                    assistant_id=assistant_id, stream=is_streaming, **thread_params
                )
                if run.status == "completed":
                    messages = litellm.get_messages(**thread_params)
                    assert isinstance(messages.data[0], Message)
                else:
                    pytest.fail(
                        "An unexpected error occurred when running the thread, {}".format(
                            run
                        )
                    )
        else:
            added_message = await litellm.a_add_message(**data)
            assert isinstance(added_message, Message)
            if is_streaming:
                async with litellm.arun_thread_stream(
                    assistant_id=assistant_id, **thread_params
                ) as event_handler:
                    print(f"run: {event_handler}")
                    assert isinstance(event_handler, AsyncAssistantEventHandler)
                    print(event_handler)
                    await event_handler.until_done()
            else:
                run = await litellm.arun_thread(
                    assistant_id=assistant_id, **thread_params
                )
                if run.status == "completed":
                    messages = await litellm.aget_messages(**thread_params)
                    assert isinstance(messages.data[0], Message)
                else:
                    pytest.fail(
                        "An unexpected error occurred when running the thread, {}".format(
                            run
                        )
                    )
    except openai.APIError as e:
        # Upstream/provider failures are reported as skips so they stay
        # visible in the test report instead of counting as passes.
        pytest.skip(f"Provider API error while running the thread: {e}")