test3 / tests /enterprise /enterprise_hooks /test_managed_files.py
DesertWolf's picture
Upload folder using huggingface_hub
447ebeb verified
import json
import os
import sys
import pytest
from fastapi import HTTPException
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
from unittest.mock import AsyncMock, MagicMock, patch
from enterprise.enterprise_hooks.managed_files import _PROXY_LiteLLMManagedFiles
from litellm.caching import DualCache
from litellm.proxy.openai_files_endpoints.common_utils import (
_is_base64_encoded_unified_file_id,
)
from litellm.types.utils import SpecialEnums
def test_get_file_ids_from_messages():
    """
    ``get_file_ids_from_messages`` returns the ``file_id`` of a ``"file"``
    content part and contributes nothing for a ``"text"`` part.
    """
    expected_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9wZGY7dW5pZmllZF9pZCxmYzdmMmVhNS0wZjUwLTQ5ZjYtODljMS03ZTZhNTRiMTIxMzg"
    handler = _PROXY_LiteLLMManagedFiles(DualCache(), prisma_client=MagicMock())

    text_part = {"type": "text", "text": "What is in this recording?"}
    file_part = {"type": "file", "file": {"file_id": expected_file_id}}
    messages = [{"role": "user", "content": [text_part, file_part]}]

    assert handler.get_file_ids_from_messages(messages) == [expected_file_id]
@pytest.mark.asyncio
async def test_async_pre_call_hook_batch_retrieve():
    """
    On an ``aretrieve_batch`` call, the pre-call hook turns a unified
    (base64-encoded) ``batch_id`` into the provider batch id and also sets
    ``model`` to the deployment encoded in that unified id.
    """
    from litellm.proxy._types import UserAPIKeyAuth

    # The stored managed-object record is owned by the same user who calls below.
    batch_record = MagicMock()
    batch_record.created_by = "123"
    prisma_client = AsyncMock()
    prisma_client.db.litellm_managedobjecttable.find_first.return_value = batch_record

    hook = _PROXY_LiteLLMManagedFiles(DualCache(), prisma_client=prisma_client)

    # Decodes to:
    # "litellm_proxy;model_id:my-general-azure-deployment;llm_batch_id:batch_a322b6ba-ac7e-4888-929c-1ad3442f06ed"
    unified_batch_id = "bGl0ZWxsbV9wcm94eTttb2RlbF9pZDpteS1nZW5lcmFsLWF6dXJlLWRlcGxveW1lbnQ7bGxtX2JhdGNoX2lkOmJhdGNoX2EzMjJiNmJhLWFjN2UtNDg4OC05MjljLTFhZDM0NDJmMDZlZA"

    result = await hook.async_pre_call_hook(
        user_api_key_dict=UserAPIKeyAuth(user_id="123", parent_otel_span=MagicMock()),
        data={"batch_id": unified_batch_id},
        call_type="aretrieve_batch",
        cache=MagicMock(),
    )

    assert result["batch_id"] == "batch_a322b6ba-ac7e-4888-929c-1ad3442f06ed"
    assert result["model"] == "my-general-azure-deployment"
# def test_list_managed_files():
# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
# # Create some test files
# file1 = proxy_managed_files.create_file(
# file=("test1.txt", b"test content 1", "text/plain"),
# purpose="assistants"
# )
# file2 = proxy_managed_files.create_file(
# file=("test2.pdf", b"test content 2", "application/pdf"),
# purpose="assistants"
# )
# # List all files
# files = proxy_managed_files.list_files()
# # Verify response
# assert len(files) == 2
# assert all(f.id.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value) for f in files)
# assert any(f.filename == "test1.txt" for f in files)
# assert any(f.filename == "test2.pdf" for f in files)
# assert all(f.purpose == "assistants" for f in files)
# def test_retrieve_managed_file():
# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
# # Create a test file
# test_content = b"test content for retrieve"
# created_file = proxy_managed_files.create_file(
# file=("test.txt", test_content, "text/plain"),
# purpose="assistants"
# )
# # Retrieve the file
# retrieved_file = proxy_managed_files.retrieve_file(created_file.id)
# # Verify response
# assert retrieved_file.id == created_file.id
# assert retrieved_file.filename == "test.txt"
# assert retrieved_file.purpose == "assistants"
# assert retrieved_file.bytes == len(test_content)
# assert retrieved_file.status == "uploaded"
# def test_delete_managed_file():
# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
# # Create a test file
# created_file = proxy_managed_files.create_file(
# file=("test.txt", b"test content", "text/plain"),
# purpose="assistants"
# )
# # Delete the file
# deleted_file = proxy_managed_files.delete_file(created_file.id)
# # Verify deletion
# assert deleted_file.id == created_file.id
# assert deleted_file.deleted == True
# # Verify file is no longer retrievable
# with pytest.raises(Exception):
# proxy_managed_files.retrieve_file(created_file.id)
# # Verify file is not in list
# files = proxy_managed_files.list_files()
# assert created_file.id not in [f.id for f in files]
# def test_retrieve_nonexistent_file():
# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
# # Try to retrieve a non-existent file
# with pytest.raises(Exception):
# proxy_managed_files.retrieve_file("nonexistent-file-id")
# def test_delete_nonexistent_file():
# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
# # Try to delete a non-existent file
# with pytest.raises(Exception):
# proxy_managed_files.delete_file("nonexistent-file-id")
# def test_list_files_with_purpose_filter():
# proxy_managed_files = _PROXY_LiteLLMManagedFiles(DualCache())
# # Create files with different purposes
# file1 = proxy_managed_files.create_file(
# file=("test1.txt", b"test content 1", "text/plain"),
# purpose="assistants"
# )
# file2 = proxy_managed_files.create_file(
# file=("test2.pdf", b"test content 2", "application/pdf"),
# purpose="batch"
# )
# # List files with purpose filter
# assistant_files = proxy_managed_files.list_files(purpose="assistants")
# batch_files = proxy_managed_files.list_files(purpose="batch")
# # Verify filtering
# assert len(assistant_files) == 1
# assert len(batch_files) == 1
# assert assistant_files[0].id == file1.id
# assert batch_files[0].id == file2.id
@pytest.mark.asyncio
async def test_async_post_call_success_hook_for_unified_finetuning_job():
    """
    The post-call success hook, given a provider fine-tuning job, returns a
    ``LiteLLMFineTuningJob`` whose ``id`` is a base64-encoded unified id
    (checked with ``_is_base64_encoded_unified_file_id``).
    """
    from litellm.types.utils import LiteLLMFineTuningJob

    unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9vY3RldC1zdHJlYW07dW5pZmllZF9pZCxiZTQ0ZDVlYi1mNDU3LTRiNzktOWM4My01N2QxMTMxYWM0YzY7dGFyZ2V0X21vZGVsX25hbWVzLGdwdC00LjEtb3BlbmFpO2xsbV9vdXRwdXRfZmlsZV9pZCxmaWxlLURKMnQ0OWZlQ2NTQk5vNG9oekZ6NGc7bGxtX291dHB1dF9maWxlX21vZGVsX2lkLGRiNjY5ODcwNzdkZTdmYzZjNzAzY2Y1MDczMGU2MmNkOWQ3YTU1N2NlNjVmMDUzNTFkYTM4YTA3ZjBlZDEyNzQ"

    # Provider-side job exactly as an upstream API would hand it back.
    job_fields = dict(
        object="fine_tuning.job",
        id="ftjob-0kEBV5b4sPrFcMnuzmYSzU1G",
        model="gpt-3.5-turbo-0613",
        created_at=1692779769,
        finished_at=None,
        fine_tuned_model=None,
        organization_id="org-dUVLhaAQ37YCGwVC2QVY8sdB",
        result_files=[],
        status="validating_files",
        validation_file=None,
        training_file="file-azQuKMLAmiFdEjxpCcbI11zF",
        hyperparameters={"n_epochs": 8},
        trained_tokens=None,
        seed=0,
    )
    provider_ft_job = LiteLLMFineTuningJob(**job_fields)
    # NOTE(review): presumably the hook reads these hidden params to build the
    # unified job id; confirm against the hook implementation.
    provider_ft_job._hidden_params = {
        "unified_file_id": unified_file_id,
        "model_id": "gpt-3.5-turbo-0613",
    }

    hook = _PROXY_LiteLLMManagedFiles(DualCache(), prisma_client=MagicMock())
    request_data = {"user_api_key_dict": {"parent_otel_span": MagicMock()}}

    result = await hook.async_post_call_success_hook(
        data=request_data,
        user_api_key_dict=MagicMock(),
        response=provider_ft_job,
    )

    assert isinstance(result, LiteLLMFineTuningJob)
    assert _is_base64_encoded_unified_file_id(result.id)
@pytest.mark.asyncio
async def test_async_pre_call_hook_for_unified_finetuning_job():
    """
    On an ``acancel_fine_tuning_job`` call, the pre-call hook replaces a
    unified ``fine_tuning_job_id`` with the provider's own job id.
    """
    from litellm.proxy._types import UserAPIKeyAuth

    # The stored managed-object record is owned by the same user who calls below.
    job_record = MagicMock()
    job_record.created_by = "123"
    prisma_client = AsyncMock()
    prisma_client.db.litellm_managedobjecttable.find_first.return_value = job_record

    hook = _PROXY_LiteLLMManagedFiles(DualCache(), prisma_client=prisma_client)

    # Decodes to:
    # "litellm_proxy;model_id:<deployment hash>;generic_response_id:ftjob-jTBys7bVsbyZDOwL9GlpYqXR"
    unified_job_id = "bGl0ZWxsbV9wcm94eTttb2RlbF9pZDo0OTIxODU4MWY3OGViZTllZjE4NDE0ZmE0ZjdmYjlmYTc0YzA5NWVkMTEyY2E4NDBkZDU2ZGZmZTliZDMwZGQxO2dlbmVyaWNfcmVzcG9uc2VfaWQ6ZnRqb2ItalRCeXM3YlZzYnlaRE93TDlHbHBZcVhS"

    result = await hook.async_pre_call_hook(
        user_api_key_dict=UserAPIKeyAuth(user_id="123", parent_otel_span=MagicMock()),
        data={"fine_tuning_job_id": unified_job_id},
        call_type="acancel_fine_tuning_job",
        cache=MagicMock(),
    )

    assert result["fine_tuning_job_id"] == "ftjob-jTBys7bVsbyZDOwL9GlpYqXR"
@pytest.mark.asyncio
@pytest.mark.parametrize("call_type", ["afile_content", "afile_delete"])
async def test_can_user_call_unified_file_id(call_type):
    """
    File content and delete calls check that the caller may access the file.

    The mocked managed-file record reports ``created_by="123"``, but the request
    is made as user ``"456"``. The pre-call hook must therefore raise
    ``HTTPException``.
    """
    from litellm.proxy._types import UserAPIKeyAuth

    owner_record = MagicMock()
    owner_record.created_by = "123"
    prisma_client = AsyncMock()
    prisma_client.db.litellm_managedfiletable.find_first.return_value = owner_record

    proxy_managed_files = _PROXY_LiteLLMManagedFiles(
        MagicMock(), prisma_client=prisma_client
    )
    unified_file_id = "bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9vY3RldC1zdHJlYW07dW5pZmllZF9pZCxmMTNlNDAzZS01YWM3LTRhZjktOGQzNS0wNDgwZDMxOTgyYTg7dGFyZ2V0X21vZGVsX25hbWVzLGdwdC00by1taW5pLW9wZW5haTtsbG1fb3V0cHV0X2ZpbGVfaWQsZmlsZS1Ib3UxZDFXc3c1SDNKcjFMYllpZDJiO2xsbV9vdXRwdXRfZmlsZV9tb2RlbF9pZCxmODBiNWU2NzQ1NzdkNjkyMjM4YmVhNTIxZDdiMGI5ZGYyY2FmMTEwMTU2YmU5YzBjM2NjMmNkNTBjOTM1ZDI0"
    # A different user from the record's creator.
    non_owner = UserAPIKeyAuth(user_id="456", parent_otel_span=MagicMock())

    with pytest.raises(HTTPException) as e:
        await proxy_managed_files.async_pre_call_hook(
            user_api_key_dict=non_owner,
            cache=MagicMock(),
            data={"file_id": unified_file_id},
            call_type=call_type,
        )