"""Medical LLaMA API: FastAPI service for text generation and fine-tuning.

Endpoints
---------
- ``GET /`` and ``GET /health``: liveness plus model/GPU status.
- ``POST /generate``: generate text for a prompt with the loaded model.
- ``POST /train`` and ``GET /train/status``: start and monitor fine-tuning.
- ``GET /model-status``: detailed diagnostics, including a tiny smoke test.
"""

import json
import logging
import os
from typing import List, Optional

import torch
from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Setup cache directory. This deliberately runs BEFORE the Hugging Face imports
# below: transformers resolves its default cache location from the environment
# when it is imported, so setting TRANSFORMERS_CACHE after the import does not
# change that default. (The from_pretrained calls also pass cache_dir.)
CACHE_DIR = "/app/cache"
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR

from datasets import load_dataset  # noqa: E402
from transformers import (  # noqa: E402
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)

# Hugging Face repository to load; overridable via the MODEL_NAME environment
# variable. The default is the repository this service originally hard-coded.
# NOTE(review): the name suggests an ONNX INT4 export -- confirm that
# AutoModelForCausalLM can load it and that Trainer can fine-tune it.
MODEL_NAME = os.environ.get(
    "MODEL_NAME", "nvidia/Meta-Llama-3.2-3B-Instruct-ONNX-INT4"
)

# Directory where fine-tuned weights are saved (checkpoints go into a
# "checkpoints" subdirectory). Overridable via MODEL_OUTPUT_PATH.
MODEL_OUTPUT_PATH = os.environ.get("MODEL_OUTPUT_PATH", "/app/model_output")


# Pydantic models for request/response
class GenerateRequest(BaseModel):
    text: str
    max_length: Optional[int] = 512
    temperature: Optional[float] = 0.7
    num_return_sequences: Optional[int] = 1


class GenerateResponse(BaseModel):
    generated_text: List[str]


class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    gpu_available: bool
    device: str


class TrainRequest(BaseModel):
    dataset_path: str
    num_epochs: Optional[int] = 3
    batch_size: Optional[int] = 4
    learning_rate: Optional[float] = 2e-5


class TrainResponse(BaseModel):
    status: str
    message: str


class TrainingStatus:
    """Process-wide, mutable record of the current fine-tuning job."""

    def __init__(self):
        self.is_training = False  # True while a job is queued or running
        self.current_epoch = 0  # last epoch value reported by the Trainer
        self.current_loss = None  # last logged training loss, if any
        self.status = "idle"  # phase name, or "failed: <error>"

    def begin(self, status: str) -> None:
        """Mark a job as active and clear progress left over from a prior run."""
        self.is_training = True
        self.current_epoch = 0
        self.current_loss = None
        self.status = status


training_status = TrainingStatus()

# Initialize FastAPI app
app = FastAPI(
    title="Medical LLaMA API",
    description="API for medical text generation using fine-tuned LLaMA model",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# Add CORS middleware.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; restrict allow_origins before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global variables for model and tokenizer; populated by startup_event().
model = None
tokenizer = None


def _device_name() -> str:
    """Return "cuda" when torch sees a CUDA device, otherwise "cpu"."""
    return "cuda" if torch.cuda.is_available() else "cpu"


@app.get("/", response_model=HealthResponse, tags=["Health"])
async def root():
    """
    Root endpoint to check API health and model status
    """
    return HealthResponse(
        status="online",
        model_loaded=model is not None,
        gpu_available=torch.cuda.is_available(),
        device=_device_name(),
    )


@app.post("/generate", response_model=GenerateResponse, tags=["Generation"])
async def generate_text(request: GenerateRequest):
    """
    Generate medical text based on input prompt
    """
    # NOTE(review): tokenization and generation run synchronously inside this
    # async handler, so they block the event loop until they finish.
    try:
        # Check if model is loaded
        if model is None or tokenizer is None:
            logger.error("Model or tokenizer not initialized")
            raise HTTPException(
                status_code=500,
                detail="Model not loaded. Please check if model was initialized correctly.",
            )

        logger.info(f"Generating text for input: {request.text[:50]}...")
        logger.info(f"Using device: {model.device}")

        # Tokenize input
        try:
            inputs = tokenizer(
                request.text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=request.max_length,
            )
            logger.info("Input tokenized successfully")
            # Move inputs to the model's device. The result is a plain dict, so
            # tensors must be read by key below, not by attribute.
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
        except Exception as e:
            logger.exception(f"Tokenization error: {str(e)}")
            raise HTTPException(
                status_code=500, detail=f"Tokenization failed: {str(e)}"
            ) from e

        generation_kwargs = {
            # max_length counts prompt tokens too; decoded output includes the prompt.
            "max_length": request.max_length,
            "num_return_sequences": request.num_return_sequences,
            "pad_token_id": tokenizer.pad_token_id,
            "eos_token_id": tokenizer.eos_token_id,
        }
        # temperature only applies when sampling is enabled, so set do_sample
        # explicitly; a non-positive temperature means greedy decoding. When
        # temperature is None, the checkpoint's generation defaults are kept.
        if request.temperature is not None:
            do_sample = request.temperature > 0
            generation_kwargs["do_sample"] = do_sample
            if do_sample:
                generation_kwargs["temperature"] = request.temperature

        # Generate text
        try:
            with torch.no_grad():
                generated_ids = model.generate(
                    inputs["input_ids"],
                    attention_mask=inputs.get("attention_mask"),
                    **generation_kwargs,
                )
            logger.info("Text generated successfully")
        except Exception as e:
            logger.exception(f"Generation error: {str(e)}")
            raise HTTPException(
                status_code=500, detail=f"Text generation failed: {str(e)}"
            ) from e

        # Decode generated text
        try:
            generated_texts = [
                tokenizer.decode(g, skip_special_tokens=True) for g in generated_ids
            ]
            logger.info("Text decoded successfully")
        except Exception as e:
            logger.exception(f"Decoding error: {str(e)}")
            raise HTTPException(
                status_code=500, detail=f"Text decoding failed: {str(e)}"
            ) from e

        return GenerateResponse(generated_text=generated_texts)

    except HTTPException:
        # Already carries the right status and message; pass it through.
        raise
    except Exception as e:
        logger.exception(f"Unexpected error: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"An unexpected error occurred: {str(e)}"
        ) from e


@app.get("/health", tags=["Health"])
async def health_check():
    """
    Check the health status of the API and model
    """
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "gpu_available": torch.cuda.is_available(),
        "device": _device_name(),
    }


@app.on_event("startup")
async def startup_event():
    """Load the tokenizer and model once when the server starts.

    Failures are logged but not raised, so the API still starts and the
    health endpoints report ``model_loaded: false``.
    """
    global tokenizer, model
    logger.info("Starting up application...")
    try:
        tokenizer, model = init_model()
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.exception(f"Failed to load model: {str(e)}")


@app.post("/train", response_model=TrainResponse, tags=["Training"])
async def train_model(request: TrainRequest, background_tasks: BackgroundTasks):
    """
    Start model training with the specified dataset

    Parameters:
    - dataset_path: Path to the JSON dataset file
    - num_epochs: Number of training epochs
    - batch_size: Training batch size
    - learning_rate: Learning rate for training
    """
    if training_status.is_training:
        raise HTTPException(status_code=400, detail="Training is already in progress")

    # Validation errors are raised outside any broad try/except so that they
    # keep their own status codes (the 404 was previously turned into a 500).
    # NOTE(review): dataset_path is client-controlled and unrestricted; consider
    # confining it to an allowed data directory.
    if not os.path.exists(request.dataset_path):
        raise HTTPException(status_code=404, detail="Dataset file not found")

    if model is None or tokenizer is None:
        raise HTTPException(
            status_code=500,
            detail="Model not loaded. Please check if model was initialized correctly.",
        )

    # Claim the training slot before responding. This handler does not await
    # between the check above and this assignment, so a second request cannot
    # slip in and start a concurrent run.
    training_status.begin("queued")
    try:
        # Start training in background
        background_tasks.add_task(
            run_training,
            request.dataset_path,
            request.num_epochs,
            request.batch_size,
            request.learning_rate,
        )
    except Exception as e:
        training_status.is_training = False
        training_status.status = "idle"
        logger.exception(f"Training setup error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return TrainResponse(status="started", message="Training started in background")


@app.get("/train/status", tags=["Training"])
async def get_training_status():
    """
    Get current training status
    """
    return {
        "is_training": training_status.is_training,
        "current_epoch": training_status.current_epoch,
        "current_loss": training_status.current_loss,
        "status": training_status.status,
    }


class _TrainingStatusCallback(TrainerCallback):
    """Mirror Trainer progress (epoch, logged loss) into ``training_status``."""

    def on_epoch_begin(self, args, state, control, **kwargs):
        training_status.current_epoch = state.epoch

    def on_log(self, args, state, control, logs=None, **kwargs):
        # Only update on logs that carry a loss, so that end-of-run summary
        # logs do not reset the last known value to None.
        if logs and "loss" in logs:
            training_status.current_loss = logs["loss"]


def run_training(
    dataset_path: str,
    num_epochs: int,
    batch_size: int,
    learning_rate: float,
    model_output_path: str = MODEL_OUTPUT_PATH,
) -> None:
    """Fine-tune the global model on a JSON dataset and save the result.

    This is a plain (sync) function on purpose. When it is scheduled as a
    FastAPI background task, Starlette runs it in a worker thread instead of
    blocking the event loop for the whole run.

    Args:
        dataset_path: JSON file loadable by ``datasets.load_dataset("json")``.
            The file is assumed to have a "text" column.
        num_epochs: number of training epochs.
        batch_size: per-device training batch size.
        learning_rate: optimizer learning rate.
        model_output_path: directory for the saved model/tokenizer. Checkpoints
            go to its "checkpoints" subdirectory.

    Failures are recorded in ``training_status.status`` as "failed: <error>"
    and logged; they are not re-raised (nothing awaits a background task).
    """
    training_status.begin("loading_dataset")
    try:
        # Load dataset
        dataset = load_dataset("json", data_files=dataset_path)

        training_status.status = "preprocessing"

        def preprocess_function(examples):
            return tokenizer(
                examples["text"],
                truncation=True,
                padding="max_length",
                max_length=512,
            )

        # Tokenize dataset
        tokenized_dataset = dataset.map(
            preprocess_function,
            batched=True,
            remove_columns=dataset["train"].column_names,
        )

        training_status.status = "training"

        # fp16 mixed precision needs a GPU and fp32 master weights.
        # TrainingArguments rejects fp16 without CUDA, and a model already
        # loaded in float16 makes the AMP gradient scaler fail.
        use_fp16 = torch.cuda.is_available() and model.dtype == torch.float32

        training_args = TrainingArguments(
            output_dir=os.path.join(model_output_path, "checkpoints"),
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=4,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            fp16=use_fp16,
            save_steps=500,
            logging_steps=100,
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset["train"],
            data_collator=DataCollatorForLanguageModeling(
                tokenizer=tokenizer, mlm=False
            ),
            callbacks=[_TrainingStatusCallback()],
        )

        trainer.train()

        # Save the model
        training_status.status = "saving"
        model.save_pretrained(model_output_path)
        tokenizer.save_pretrained(model_output_path)

        training_status.status = "completed"
        logger.info("Training completed successfully")

    except Exception as e:
        training_status.status = f"failed: {str(e)}"
        logger.exception(f"Training error: {str(e)}")
    finally:
        training_status.is_training = False


def init_model():
    """Load the tokenizer and causal-LM model named by ``MODEL_NAME``.

    Returns:
        ``(tokenizer, model)``. If the tokenizer has no pad token, it falls
        back to its EOS token. The model is loaded in float16 when CUDA is
        available and float32 otherwise, with ``device_map="auto"``.

    Raises:
        Exception: anything raised by the Hugging Face loaders; it is logged,
        then re-raised.
    """
    # NOTE(review): trust_remote_code=True executes Python code shipped with
    # the model repository; only use it with repositories you trust.
    try:
        device = _device_name()
        logger.info(f"Loading model on device: {device}")

        # Load tokenizer
        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME, cache_dir=CACHE_DIR, trust_remote_code=True
        )

        # Add padding token if not present
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        logger.info("Loading model...")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map="auto",
            cache_dir=CACHE_DIR,
            trust_remote_code=True,
        )

        logger.info(f"Model loaded successfully on {device}")
        return tokenizer, model

    except Exception as e:
        logger.exception(f"Model initialization error: {str(e)}")
        raise


@app.get("/model-status", tags=["Health"])
async def model_status():
    """
    Get detailed model status
    """
    try:
        cuda = torch.cuda.is_available()
        model_info = {
            "model_loaded": model is not None,
            "tokenizer_loaded": tokenizer is not None,
            "model_device": str(model.device) if model else None,
            "gpu_available": cuda,
            "cuda_device_count": torch.cuda.device_count() if cuda else 0,
            "cuda_device_name": torch.cuda.get_device_name(0) if cuda else None,
            "model_type": type(model).__name__ if model else None,
            "tokenizer_type": type(tokenizer).__name__ if tokenizer else None,
        }

        if model is not None:
            test_input = None
            try:
                # Test tokenizer
                test_input = tokenizer("test", return_tensors="pt")
                model_info["tokenizer_working"] = True
            except Exception as e:
                model_info["tokenizer_working"] = False
                model_info["tokenizer_error"] = str(e)

            if test_input is None:
                # Without tokenized input the model cannot be exercised; report
                # that plainly instead of a confusing unbound-variable error.
                model_info["model_working"] = False
                model_info["model_error"] = "skipped: tokenizer test failed"
            else:
                try:
                    # Test model forward pass (a very short generation)
                    with torch.no_grad():
                        model.generate(
                            test_input["input_ids"].to(model.device),
                            max_length=10,
                        )
                    model_info["model_working"] = True
                except Exception as e:
                    model_info["model_working"] = False
                    model_info["model_error"] = str(e)

        return model_info

    except Exception as e:
        logger.exception(f"Error checking model status: {str(e)}")
        return {
            "error": str(e),
            "model_loaded": model is not None,
            "tokenizer_loaded": tokenizer is not None,
        }