|
import os
import warnings
from typing import Annotated

from fastapi import FastAPI, Depends
from fastapi.responses import StreamingResponse
from mistralai import Mistral
from pydantic import BaseModel

from auth import verify_token
|
|
|
app = FastAPI()

# Keep the API key under its own name: the POST /mistral endpoint function
# defined further down is also called ``mistral`` and rebinds that module name.
MISTRAL_API_KEY = os.environ.get('MISTRAL_KEY', '')
if not MISTRAL_API_KEY:
    # Startup still proceeds with an empty key (as before), but surface the
    # misconfiguration now rather than only via failing upstream requests.
    warnings.warn(
        "MISTRAL_KEY is not set; requests to the Mistral API will likely be rejected.",
        RuntimeWarning,
        stacklevel=1,
    )

# Single shared client used by the request handlers below.
mistral_client = Mistral(api_key=MISTRAL_API_KEY)
|
|
|
@app.get("/")
def hello():
    # Root endpoint: always answers with the same constant JSON object.
    # (Comment rather than docstring so the OpenAPI description is unchanged.)
    payload = dict(Hello="World!")
    return payload
|
|
|
class MistralRequest(BaseModel):
    # JSON request body accepted by POST /mistral.
    # Documented with comments, not a docstring, because pydantic would copy a
    # docstring into the generated JSON schema / OpenAPI description.

    # Model identifier, forwarded verbatim as ``model=`` to the Mistral client.
    model: str

    # Text sent to the model as the content of a single "user" chat message.
    prompt: str
|
|
|
def _delta_text(chunk) -> str:
    """Return the text carried by one streamed chat-completion chunk.

    Returns an empty string when the chunk carries no forwardable text: it has
    no choices, its delta content is ``None`` (or any other non-text value), or
    its content parts expose no string ``text`` attribute.
    """
    choices = chunk.data.choices
    if not choices:
        # A chunk with an empty ``choices`` list would make ``choices[0]`` raise.
        return ""
    content = choices[0].delta.content
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # Content may arrive as a list of typed parts; keep only plain-text
        # ones. NOTE(review): assumes text parts expose a ``.text`` string.
        return "".join(
            part.text
            for part in content
            if isinstance(getattr(part, "text", None), str)
        )
    # Anything else cannot be encoded by StreamingResponse; forward nothing.
    return ""


@app.post("/mistral")
async def mistral(request: MistralRequest, token: Annotated[str, Depends(verify_token)]):
    """Stream a Mistral chat completion for a single user prompt as plain text."""
    # ``token`` is not read here; declaring it makes FastAPI run
    # ``verify_token`` first (presumably rejecting unauthenticated callers).

    # Open the upstream stream *before* building the response: if the call
    # fails (bad model, auth, network) the exception propagates while no
    # headers have been sent, yielding an HTTP error instead of a truncated 200.
    stream = await mistral_client.chat.stream_async(
        model=request.model,
        messages=[
            {
                "role": "user",
                "content": request.prompt,
            }
        ],
    )

    async def generate():
        # Relay only non-empty text deltas to the client as they arrive.
        async for chunk in stream:
            text = _delta_text(chunk)
            if text:
                yield text

    return StreamingResponse(generate(), media_type="text/plain")