import asyncio
import os
import time
from typing import Annotated, List

from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from mistralai import Mistral
from pydantic import BaseModel

from auth import verify_token
from schemas.common import APIResponse, ChatMessage, ChatMemoryRequest

router = APIRouter(prefix="/mistral", tags=["mistral"])

# Fail fast at startup if the API key is missing rather than on the first request.
mistral_key = os.environ.get("MISTRAL_KEY", "")
if not mistral_key:
    raise RuntimeError("MISTRAL_KEY environment variable not set.")
mistral_client = Mistral(api_key=mistral_key)

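# Process-wide throttle shared by the endpoints below: the lock plus a timestamp
# ensure that at most one Mistral API call is started per MIN_INTERVAL seconds
# within this worker process.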
last_mistral_call_time = 0.0
mistral_call_lock = asyncio.Lock()
MIN_INTERVAL = 1.0


class LLMRequest(BaseModel):
    model: str
    prompt: str

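# POST /mistral/chat-stream: stream a completion for a single prompt back to the
# client as plain text chunks.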
@router.post("/chat-stream") |
|
async def mistral_chat_stream(request: LLMRequest, token: Annotated[str, Depends(verify_token)]): |
|
async def generate(): |
|
global last_mistral_call_time |
|
async with mistral_call_lock: |
|
current_time = time.monotonic() |
|
elapsed = current_time - last_mistral_call_time |
|
if elapsed < MIN_INTERVAL: |
|
await asyncio.sleep(MIN_INTERVAL - elapsed) |
|
last_mistral_call_time = time.monotonic() |
|
|
|
try: |
|
response = await mistral_client.chat.stream_async( |
|
model=request.model, |
|
messages=[ |
|
{ |
|
"role": "user", |
|
"content": request.prompt, |
|
} |
|
], |
|
) |
|
async for chunk in response: |
|
|
|
if hasattr(chunk, 'choices') and chunk.choices: |
|
if chunk.choices[0].delta.content is not None: |
|
yield chunk.choices[0].delta.content |
|
|
|
|
|
|
|
except Exception as e: |
|
|
|
print(f"Error during Mistral stream: {e}") |
|
yield f"Error: {str(e)}" |
|
|
|
return StreamingResponse(generate(), media_type="text/plain") |
|
|
|
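# POST /mistral/chat: non-streaming completion for a single user prompt.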
@router.post("/chat", response_model=APIResponse) |
|
async def mistral_chat(request: LLMRequest, token: Annotated[str, Depends(verify_token)]): |
|
global last_mistral_call_time |
|
async with mistral_call_lock: |
|
current_time = time.monotonic() |
|
elapsed = current_time - last_mistral_call_time |
|
if elapsed < MIN_INTERVAL: |
|
await asyncio.sleep(MIN_INTERVAL - elapsed) |
|
last_mistral_call_time = time.monotonic() |
|
|
|
try: |
|
response = await mistral_client.chat.complete_async( |
|
model=request.model, |
|
messages=[ |
|
{ |
|
"role": "user", |
|
"content": request.prompt, |
|
} |
|
], |
|
) |
|
if response.choices and response.choices[0].message: |
|
content = response.choices[0].message.content |
|
return APIResponse(success=True, data={"response": content}) |
|
else: |
|
return APIResponse(success=False, error="No response content received from Mistral.", data=response.dict()) |
|
except Exception as e: |
|
print(f"Error calling Mistral chat completion: {e}") |
|
raise HTTPException(status_code=500, detail=f"Mistral API error: {str(e)}") |
|
|
|
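# POST /mistral/chat-with-memory: non-streaming completion over a full message
# history supplied by the client.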
@router.post("/chat-with-memory", response_model=APIResponse) |
|
async def mistral_chat_with_memory(request: ChatMemoryRequest, token: Annotated[str, Depends(verify_token)]): |
|
global last_mistral_call_time |
|
async with mistral_call_lock: |
|
current_time = time.monotonic() |
|
elapsed = current_time - last_mistral_call_time |
|
if elapsed < MIN_INTERVAL: |
|
await asyncio.sleep(MIN_INTERVAL - elapsed) |
|
last_mistral_call_time = time.monotonic() |
|
|
|
try: |
|
|
|
messages_dict = [msg.dict() for msg in request.messages] |
|
|
|
response = await mistral_client.chat.complete_async( |
|
model=request.model, |
|
messages=messages_dict, |
|
) |
|
if response.choices and response.choices[0].message: |
|
content = response.choices[0].message.content |
|
return APIResponse(success=True, data={"response": content}) |
|
else: |
|
return APIResponse(success=False, error="No response content received from Mistral.", data=response.dict()) |
|
except Exception as e: |
|
print(f"Error calling Mistral chat completion with memory: {e}") |
|
raise HTTPException(status_code=500, detail=f"Mistral API error: {str(e)}") |