from fastapi import FastAPI, File, UploadFile, Request
from fastapi.responses import StreamingResponse, JSONResponse
from transformers import pipeline
from TTS.api import TTS
import whisper
import torch
from io import BytesIO
from PIL import Image
import base64
import os

app = FastAPI(title="NasFit AI Server")

# Shared secret for the Bearer-token check in check_auth(); override the
# default via the API_KEY environment variable.
API_KEY = os.getenv("API_KEY", "nasfit_secret_key")

print("Loading models...")

# All models are loaded once at startup and shared across requests.
chat_pipe = pipeline("text-generation", model="meta-llama/Meta-Llama-3-8B-Instruct")
vision_pipe = pipeline("image-text-to-text", model="lmms-lab/llava-onevision-1.6-7b-hf")
whisper_model = whisper.load_model("small")
tts = TTS("coqui/XTTS-v2")

print("Models ready.")


async def check_auth(request: Request):
    # Require an "Authorization: Bearer <API_KEY>" header on every request.
    auth = request.headers.get("Authorization", "")
    if not auth or not auth.startswith("Bearer ") or auth.split(" ")[1] != API_KEY:
        return False
    return True


@app.post("/v1/chat/completions")
async def chat_endpoint(request: Request):
    if not await check_auth(request):
        return JSONResponse({"error": "Unauthorized"}, status_code=401)
    payload = await request.json()
    messages = payload.get("messages", [])
    model = payload.get("model", "llama3")  # accepted for OpenAI compatibility; a single local model is served

    # Collect the prompt text and an optional base64 data-URI image from the
    # OpenAI-style message list.
    image_content = None
    text_content = ""
    for msg in messages:
        content = msg.get("content", "")
        if isinstance(content, list):
            for c in content:
                if c.get("type") == "text":
                    text_content += c.get("text", "")
                elif c.get("type") == "image_url":
                    img_url = c["image_url"]["url"]
                    if img_url.startswith("data:image"):
                        image_content = Image.open(BytesIO(base64.b64decode(img_url.split(",")[1])))
        else:
            text_content += content

    # Use the vision pipeline when an image was supplied, otherwise plain text generation.
    if image_content:
        response = vision_pipe(images=image_content, text=text_content)[0]["generated_text"]
    else:
        # return_full_text=False keeps the echoed prompt out of the reply.
        response = chat_pipe(text_content, max_new_tokens=300, return_full_text=False)[0]["generated_text"]

    return {"choices": [{"message": {"content": response}}]}
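
# Illustrative request for this endpoint (shape inferred from the parsing
# above; host, port, and the sample values are placeholders, not part of the
# original code):
#
#   curl -X POST http://localhost:8000/v1/chat/completions \
#     -H "Authorization: Bearer $API_KEY" \
#     -H "Content-Type: application/json" \
#     -d '{"model": "llama3",
#          "messages": [{"role": "user", "content": [
#            {"type": "text", "text": "What is in this image?"},
#            {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}}]}]}'
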
@app.post("/v1/audio/transcriptions")
async def transcribe(request: Request, file: UploadFile = File(...)):
    if not await check_auth(request):
        return JSONResponse({"error": "Unauthorized"}, status_code=401)
    # Whisper's high-level API expects a file path, so persist the upload to
    # disk before transcribing.
    audio = await file.read()
    with open("temp.wav", "wb") as f:
        f.write(audio)
    result = whisper_model.transcribe("temp.wav")
    return {"text": result["text"]}
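
# Illustrative request (multipart upload; host, port, and the file name are
# placeholders):
#
#   curl -X POST http://localhost:8000/v1/audio/transcriptions \
#     -H "Authorization: Bearer $API_KEY" \
#     -F "file=@recording.wav"
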
@app.post("/v1/audio/speech")
async def tts_endpoint(request: Request):
    if not await check_auth(request):
        return JSONResponse({"error": "Unauthorized"}, status_code=401)
    payload = await request.json()
    text = payload.get("input", "")
    voice = payload.get("voice", "es_male_01")
    # XTTS is multilingual and needs a language code; the optional "language"
    # payload field and the "es" default are assumptions matching the default voice.
    language = payload.get("language", "es")
    # Synthesize to a file, then stream the WAV bytes back to the client.
    tts.tts_to_file(text=text, file_path="output.wav", speaker=voice, language=language)
    with open("output.wav", "rb") as f:
        audio = f.read()
    return StreamingResponse(BytesIO(audio), media_type="audio/wav")
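
# Illustrative request (host, port, sample text, and output path are
# placeholders):
#
#   curl -X POST http://localhost:8000/v1/audio/speech \
#     -H "Authorization: Bearer $API_KEY" \
#     -H "Content-Type: application/json" \
#     -d '{"input": "Hello!", "voice": "es_male_01"}' \
#     --output speech.wav


# Minimal entry point for running the server directly; assumes uvicorn is
# installed and that port 8000 is acceptable (neither is specified elsewhere
# in this file).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)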