# Luna / apiServer.py
# Uploaded by Chun121 ("Upload 4 files", commit ab343a9).
from fastapi import FastAPI, HTTPException, Body
from pydantic import BaseModel
from typing import List, Optional
from model_handler import ModelHandler
import uvicorn
from pathlib import Path
import json
# ASGI application; the title is what FastAPI shows in the generated API docs.
app = FastAPI(title="Luna Chat API")
# Single module-level handler shared by every endpoint below. It is constructed
# at import time, so whatever ModelHandler.__init__ does (presumably loading the
# model -- not visible in this file) happens as soon as this module is imported.
model_handler = ModelHandler()
class Message(BaseModel):
    """A single chat turn, in the dict shape exchanged with ModelHandler."""

    # Speaker of the turn; the /chat handler looks for entries whose role is
    # "assistant". No other values are validated here.
    role: str
    # Text of the turn.
    content: str
class ChatRequest(BaseModel):
    """Request body for ``POST /chat``."""

    # Conversation so far, in order; forwarded to ModelHandler.generate_response.
    messages: List[Message]
    # Sampling temperature. Range-checked by validate_temperature(), which the
    # /chat handler calls explicitly (it is not a pydantic validator).
    temperature: Optional[float] = 0.7

    class Config:
        # Example payload attached to the model's JSON schema (pydantic v1
        # ``schema_extra``), which FastAPI surfaces in its OpenAPI docs.
        schema_extra = {
            "example": {
                "messages": [{"role": "user", "content": "Hello"}],
                "temperature": 0.7
            }
        }

    def validate_temperature(self):
        """Check that ``temperature`` lies within the closed range [0.1, 2.0].

        An explicit ``null`` temperature (allowed by the ``Optional`` type) and
        NaN are rejected too. Previously ``None`` crashed with a ``TypeError``
        on the comparison, and NaN slipped through because every comparison
        with NaN is False.

        Raises:
            ValueError: if the temperature is missing, NaN, or out of range.
        """
        temperature = self.temperature
        # The chained comparison is False for NaN, so NaN is rejected here.
        if temperature is None or not (0.1 <= temperature <= 2.0):
            raise ValueError("Temperature must be between 0.1 and 2.0")
class ChatResponse(BaseModel):
    """Response body for ``POST /chat``."""

    # Assistant reply text extracted by the /chat handler from the model
    # output ("" when the output contains no assistant message).
    response: str
    # The messages returned by ModelHandler.generate_response, as Message models.
    history: List[Message]
class ExportRequest(BaseModel):
    """Request body for ``POST /export``."""

    # Conversation to save; serialized to plain dicts before saving.
    history: List[Message]
    # Optional target filename, passed through to ModelHandler.save_chat_history.
    # How None is handled is up to that method (not visible here).
    filename: Optional[str] = None
class ImportRequest(BaseModel):
    """Request body for ``POST /import``."""

    # Server-side path, passed as-is to ModelHandler.load_chat_history.
    filepath: str
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Run the model on the conversation and return its reply and history.

    Returns:
        ChatResponse with ``response`` set to the most recent assistant
        message in the model output ("" if there is none), and ``history``
        set to that output converted to ``Message`` models.

    Raises:
        HTTPException(422): the requested temperature is invalid.
        HTTPException(500): any failure while generating or packaging the reply.
    """
    try:
        request.validate_temperature()
    except (TypeError, ValueError) as e:
        # A bad temperature is a client error, not a server fault. TypeError
        # covers validators that compare a null temperature against the bounds.
        raise HTTPException(status_code=422, detail=str(e)) from e

    try:
        # NOTE(review): generate_response is called synchronously inside an
        # async endpoint, so it blocks the event loop while it runs. Consider a
        # plain `def` endpoint or a threadpool if ModelHandler is safe to call
        # from worker threads.
        response = model_handler.generate_response(
            [msg.dict() for msg in request.messages],
            temperature=request.temperature,
        )
        history = [Message(**msg) for msg in response]
        # The model output is treated as the full history, so it can contain
        # earlier assistant turns. The new reply is the LAST assistant message;
        # scanning forwards would return a stale reply on multi-turn chats.
        assistant_response = next(
            (msg.content for msg in reversed(history) if msg.role == "assistant"),
            "",
        )
        return ChatResponse(response=assistant_response, history=history)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/export")
async def export_chat(request: ExportRequest):
    """Save ``request.history`` through the model handler.

    The history is turned into plain dicts and handed to
    ``ModelHandler.save_chat_history`` together with the optional filename.
    The path that method returns is echoed back to the client. Any failure is
    reported as HTTP 500, with the exception text as the detail.
    """
    # NOTE(review): request.filename is client-supplied and passed through
    # unchanged; confirm save_chat_history confines it to a safe directory.
    try:
        payload = [entry.dict() for entry in request.history]
        saved_to = model_handler.save_chat_history(payload, request.filename)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"status": "success", "filepath": saved_to}
@app.post("/import")
async def import_chat(request: ImportRequest):
    """Load a saved chat history from ``request.filepath``.

    Returns:
        ``{"status": "success", "history": ...}``, where history is whatever
        ``ModelHandler.load_chat_history`` returned.

    Raises:
        HTTPException(404): the file does not exist.
        HTTPException(500): any other failure while loading.
    """
    # SECURITY(review): filepath comes straight from the client and is passed
    # to load_chat_history unchanged. Unless that method restricts it (not
    # visible here), any file readable by the server process can be
    # requested -- and the server binds 0.0.0.0. Consider confining paths to
    # the export directory.
    try:
        history = model_handler.load_chat_history(request.filepath)
    except FileNotFoundError as e:
        # A missing file is a client-side problem, not a server fault.
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"status": "success", "history": history}
@app.get("/model/status")
async def model_status():
    """Report whether the handler currently holds a model, plus its hash."""
    if model_handler.model is None:
        load_state = "not_loaded"
    else:
        load_state = "loaded"
    return {"status": load_state, "model_hash": model_handler.get_model_hash()}
if __name__ == "__main__":
    # uvicorn only honours reload=True when the app is given as an import
    # string. With the app object, it logs a warning and exits immediately,
    # so the previous call never served anything. Serve the object directly.
    # For auto-reload during development, run instead (assuming this file is
    # apiServer.py): `uvicorn apiServer:app --reload --port 8000`
    uvicorn.run(app, host="0.0.0.0", port=8000)