Spaces:
Sleeping
Sleeping
| import os | |
| import shutil | |
| import tempfile | |
| from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query | |
| from fastapi.encoders import jsonable_encoder | |
| from bson import ObjectId | |
| from models import InitializeBotResponse, NewChatResponse, QueryRequest, QueryResponse | |
| from trainer_manager import get_trainer | |
| from config import CUSTOM_PROMPT | |
| from prompt_templates import PromptTemplates | |
| router = APIRouter() | |
| trainer = get_trainer() | |
| def initialize_bot(prompt_type: str = Query(None)): | |
| """ | |
| Initializes a new bot and returns its bot_id. | |
| Accepts an optional 'prompt_type' query parameter (provided by the frontend). | |
| """ | |
| try: | |
| bot_id = trainer.initialize_bot_id() | |
| # Optionally, you might want to store the prompt_type with the bot record. | |
| return InitializeBotResponse(bot_id=bot_id) | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def upload_document(bot_id: str = Form(...), file: UploadFile = File(...)): | |
| """ | |
| Saves the uploaded file temporarily and adds it to the specified bot's knowledge base. | |
| """ | |
| try: | |
| # Save the file to a temporary location. | |
| with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp: | |
| contents = await file.read() | |
| tmp.write(contents) | |
| tmp_path = tmp.name | |
| # Add the document using the temporary file path to the specified bot. | |
| trainer.add_document_from_path(tmp_path, bot_id) | |
| # Remove the temporary file. | |
| os.remove(tmp_path) | |
| return {"message": "Document uploaded and added successfully."} | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| def create_bot(bot_id: str, prompt_type: str = Query(None)): | |
| """ | |
| Finalizes the creation (build) of the bot identified by bot_id. | |
| Uses the provided (or default) prompt_type to determine the custom prompt template. | |
| If no prompt_type is provided, it defaults to "quiz_solving". | |
| """ | |
| try: | |
| if prompt_type is None: | |
| prompt_type = "quiz_solving" | |
| # Determine the appropriate prompt template. | |
| if prompt_type == "university": | |
| prompt_template = PromptTemplates.get_university_chatbot_prompt() | |
| elif prompt_type == "quiz_solving": | |
| prompt_template = PromptTemplates.get_quiz_solving_prompt() | |
| elif prompt_type == "assignment_solving": | |
| prompt_template = PromptTemplates.get_assignment_solving_prompt() | |
| elif prompt_type == "paper_solving": | |
| prompt_template = PromptTemplates.get_paper_solving_prompt() | |
| elif prompt_type == "quiz_creation": | |
| prompt_template = PromptTemplates.get_quiz_creation_prompt() | |
| elif prompt_type == "assignment_creation": | |
| prompt_template = PromptTemplates.get_assignment_creation_prompt() | |
| elif prompt_type == "paper_creation": | |
| prompt_template = PromptTemplates.get_paper_creation_prompt() | |
| elif prompt_type == "check_quiz": | |
| prompt_template = PromptTemplates.get_check_quiz_prompt() | |
| elif prompt_type == "check_assignment": | |
| prompt_template = PromptTemplates.get_check_assignment_prompt() | |
| elif prompt_type == "check_paper": | |
| prompt_template = PromptTemplates.get_check_paper_prompt() | |
| else: | |
| prompt_template = PromptTemplates.get_quiz_solving_prompt() | |
| # Create (build) the bot using the specified bot_id and prompt template. | |
| trainer.create_bot(bot_id, prompt_template) | |
| return {"message": f"Bot {bot_id} created successfully."} | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| def new_chat(bot_id: str): | |
| """ | |
| Creates a new chat session for the specified bot. | |
| """ | |
| try: | |
| chat_id = trainer.new_chat(bot_id) | |
| return NewChatResponse(chat_id=chat_id) | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| def send_query(query_request: QueryRequest): | |
| """ | |
| Processes a query and returns the bot's response along with any web sources. | |
| The request must include bot_id, chat_id, and the query text. | |
| """ | |
| try: | |
| response, web_sources = trainer.get_response( | |
| query_request.query, query_request.bot_id, query_request.chat_id | |
| ) | |
| return QueryResponse(response=response, web_sources=web_sources) | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| def list_chats(bot_id: str): | |
| """ | |
| Returns a list of previous chat sessions for the specified bot. | |
| """ | |
| try: | |
| chats = trainer.list_chats(bot_id) | |
| return chats | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| def chat_history(chat_id: str, bot_id: str = Query(None)): | |
| """ | |
| Returns the chat history for a given chat session. | |
| The bot_id can be provided as a query parameter (if needed). | |
| ObjectId instances in the history are converted to strings. | |
| """ | |
| try: | |
| history = trainer.get_chat_by_id(chat_id=chat_id) | |
| return jsonable_encoder(history, custom_encoder={ObjectId: str}) | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=str(e)) | |