import json
import logging
import os
import sys
from typing import Any, Dict, Optional

import torch
from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer

# Define base model directory
BASE_MODEL_DIR = "./models/"


# Configure logging with fallback to stdout if file writing fails
def setup_logging():
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)

    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )

    # Always add a stdout handler
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(formatter)
    logger.addHandler(stdout_handler)

    # Try to add a file handler, but don't fail if we can't
    try:
        # Use a logs/ directory under the current working directory
        log_dir = os.path.join(os.getcwd(), "logs")
        os.makedirs(log_dir, exist_ok=True)
        file_handler = logging.FileHandler(
            os.path.join(log_dir, "poetry_generation.log")
        )
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    except (OSError, PermissionError) as e:
        print(f"Warning: Could not create log file (using stdout only): {e}")

    return logger


# Set up logging
logger = setup_logging()


class GenerateRequest(BaseModel):
    prompt: str = Field(..., min_length=1, max_length=500)
    max_length: Optional[int] = Field(default=100, ge=10, le=500)
    temperature: float = Field(default=0.9, ge=0.1, le=2.0)
    top_k: int = Field(default=50, ge=1, le=100)
    top_p: float = Field(default=0.95, ge=0.1, le=1.0)
    repetition_penalty: float = Field(default=1.2, ge=1.0, le=2.0)


class ModelManager:
    def __init__(self):
        self.model = None
        self.tokenizer = None

    def initialize(self) -> bool:
        """Initialize the model and tokenizer."""
        try:
            logger.info("Loading tokenizer...")
            # Start from the base GPT-2 tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained("gpt2")

            # Optionally patch in a custom vocabulary. Note: this only
            # overwrites the in-memory mappings and does not rebuild the
            # tokenizer's merges, so the custom vocab must remain
            # GPT-2 compatible.
            vocab_path = os.path.join(BASE_MODEL_DIR, "vocab.json")
            if os.path.exists(vocab_path):
                try:
                    with open(vocab_path, "r", encoding="utf-8") as f:
                        custom_vocab = json.load(f)
                    self.tokenizer.vocab = custom_vocab
                    self.tokenizer.ids_to_tokens = {
                        v: k for k, v in custom_vocab.items()
                    }
                except Exception as e:
                    logger.warning(f"Could not load custom vocabulary: {e}")

            logger.info("Loading model...")
            model_path = os.path.join(BASE_MODEL_DIR, "poeticagpt.pth")
            if not os.path.exists(model_path):
                logger.error(f"Model file not found at {model_path}")
                return False

            # from_pretrained() expects a directory or hub id, not a .pth
            # file, so instantiate the GPT-2 architecture and load the
            # fine-tuned weights as a state dict (assumes the checkpoint
            # was saved from a GPT-2-compatible model).
            self.model = AutoModelForCausalLM.from_pretrained("gpt2")
            state_dict = torch.load(model_path, map_location="cpu")
            self.model.load_state_dict(state_dict)

            # Force model to CPU and switch to inference mode
            self.model.to("cpu")
            self.model.eval()

            logger.info("Model and tokenizer loaded successfully")
            return True
        except Exception as e:
            logger.error(f"Error initializing model: {e}")
            logger.exception("Detailed traceback:")
            return False

    def generate(self, request: GenerateRequest) -> Dict[str, Any]:
        """Generate poetry based on the request parameters."""
        if self.model is None or self.tokenizer is None:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Model or tokenizer not loaded",
            )

        try:
            # Encode the prompt and build an all-ones attention mask
            inputs = self.tokenizer.encode(request.prompt, return_tensors="pt")
            attention_mask = torch.ones(inputs.shape, dtype=torch.long)

            # Sample a single continuation with the requested decoding params
            with torch.no_grad():
                outputs = self.model.generate(
                    inputs,
                    attention_mask=attention_mask,
                    max_length=request.max_length,
                    num_return_sequences=1,
                    temperature=request.temperature,
                    top_k=request.top_k,
                    top_p=request.top_p,
                    repetition_penalty=request.repetition_penalty,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                )

            generated_text = self.tokenizer.decode(
                outputs[0], skip_special_tokens=True
            )

            return {
                "generated_text": generated_text,
                "prompt": request.prompt,
                "parameters": {
                    "max_length": request.max_length,
                    "temperature": request.temperature,
                    "top_k": request.top_k,
                    "top_p": request.top_p,
                    "repetition_penalty": request.repetition_penalty,
                },
            }
        except Exception as e:
            logger.error(f"Error generating text: {e}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=str(e),
            )


# Create FastAPI app and model manager
app = FastAPI(title="Poetry Generation API")
model_manager = ModelManager()


@app.on_event("startup")
async def startup():
    """Initialize the model during startup."""
    if not model_manager.initialize():
        logger.error("Failed to initialize model manager")
        # In production we may want to keep serving (e.g. /health) even if
        # the model fails to load, so log the error instead of exiting.
        # sys.exit(1)


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {
        "status": "healthy",
        "model_loaded": model_manager.model is not None,
        "tokenizer_loaded": model_manager.tokenizer is not None,
    }


@app.post("/generate")
async def generate_text(request: GenerateRequest):
    """Generate poetry with the given parameters."""
    return model_manager.generate(request)


@app.on_event("shutdown")
async def shutdown():
    """Cleanup on shutdown."""
    if model_manager.model is not None:
        del model_manager.model
    if model_manager.tokenizer is not None:
        del model_manager.tokenizer
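

# A minimal sketch of how this service might be run and exercised locally.
# Assumes the module is saved as `main.py` and that `uvicorn` is installed;
# both names are illustrative, not part of the original service.
if __name__ == "__main__":
    import uvicorn

    # Serve the app directly: python main.py
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request once the server is up (hypothetical prompt):
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "The autumn moon", "temperature": 0.8}'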