File size: 2,813 Bytes
c4e9412 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import logging
from QwenChat import EnterpriseQwenChat # Import the chat system from QwenChat
# Create the application object; the route decorators below register their
# handlers on it, and it is passed to uvicorn when this file runs as a script.
app = FastAPI()
# Configure the root logger: INFO and above, each record rendered as
# "<timestamp> - <message>".
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
# Build the single chat system instance that every endpoint in this module
# shares (conversation state is therefore global to the process).
# NOTE(review): this runs at import time; presumably it loads the model from
# model_directory, so importing this module may be slow - confirm in QwenChat.
model_directory = "./qwen" # Path to the model directory
chat_system = EnterpriseQwenChat(model_directory=model_directory)
# Define input schema for the API
class ChatRequest(BaseModel):
    """Request body accepted by the ``/chat`` endpoint."""

    # Text submitted by the caller; the chat handler strips surrounding
    # whitespace from it before use.
    user_input: str
@app.get("/")
def health_check():
    """Liveness probe: return a fixed payload saying the API is up."""
    payload = {
        "status": "Healthy",
        "message": "Model API is live!",
    }
    return payload
@app.post("/chat")
def chat(request: ChatRequest):
    """
    Chat endpoint: handle one user message and return the model's reply.

    Args:
        request: Request body carrying the user's text in ``user_input``.

    Returns:
        A dict with ``response`` (the generated reply) and ``conversation``
        (every turn currently held by the conversation manager, each as a
        ``{"role": ..., "content": ...}`` dict).

    Raises:
        HTTPException: 400 when the input is empty after stripping
            whitespace; 500 when recording the turn or generating the
            reply fails.
    """
    user_input = request.user_input.strip()
    # Validate BEFORE the try block. HTTPException derives from Exception, so
    # raising it inside the broad handler below would be caught and re-raised
    # as a 500, hiding the client error.
    if not user_input:
        raise HTTPException(status_code=400, detail="User input cannot be empty.")

    manager = chat_system.conversation_manager
    try:
        # Add user input to the conversation
        manager.add_turn("user", user_input)

        # Generate AI response; the current turn count is passed as the
        # second argument to the generator.
        prompt = manager.get_prompt()
        response = chat_system.response_generator.generate_response(
            prompt, len(manager.turns)
        )

        # Add response to conversation history
        manager.add_turn("assistant", response)

        return {
            "response": response,
            "conversation": [
                {"role": turn.role, "content": turn.content}
                for turn in manager.turns
            ],
        }
    except HTTPException:
        # A deliberate HTTP error from further down keeps its own status code.
        raise
    except Exception as e:
        # logging.exception records the traceback alongside the message.
        logging.exception("Error in chat endpoint: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/clear")
def clear_conversation():
    """
    Remove every turn from the shared conversation history.

    Returns a confirmation message on success; any failure is logged and
    reported to the client as a 500 carrying the error text.
    """
    try:
        history = chat_system.conversation_manager.turns
        history.clear()
    except Exception as err:
        reason = str(err)
        logging.error(f"Error in clear endpoint: {reason}")
        raise HTTPException(status_code=500, detail=reason)
    else:
        return {"message": "Conversation history cleared."}
@app.get("/save")
def save_conversation():
    """
    Ask the chat system to save the current conversation.

    Returns a message embedding whatever ``chat_system.save_conversation()``
    returned (treated here as the output filename); any failure is logged and
    reported to the client as a 500 carrying the error text.
    """
    try:
        saved_to = chat_system.save_conversation()
        confirmation = f"Conversation saved to {saved_to}"
    except Exception as err:
        reason = str(err)
        logging.error(f"Error in save endpoint: {reason}")
        raise HTTPException(status_code=500, detail=reason)
    else:
        return {"message": confirmation}
if __name__ == "__main__":
    # Imported here so uvicorn is only required when run as a script.
    import uvicorn

    # Serve the app on port 8080, bound to 0.0.0.0.
    uvicorn.run(app, port=8080, host="0.0.0.0")
|