import os
import json
import logging
from typing import List, Optional

import torch
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    TrainerCallback,
    DataCollatorForLanguageModeling,
)

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Setup cache directory
os.makedirs("/app/cache", exist_ok=True)
os.environ["TRANSFORMERS_CACHE"] = "/app/cache"

# Output directory for fine-tuned model artifacts. The original file referenced
# model_output_path without defining it; this default path is an assumption.
model_output_path = "/app/model_output"
os.makedirs(model_output_path, exist_ok=True)
# Pydantic models for request/response
class GenerateRequest(BaseModel):
    text: str
    max_length: Optional[int] = 512
    temperature: Optional[float] = 0.7
    num_return_sequences: Optional[int] = 1


class GenerateResponse(BaseModel):
    generated_text: List[str]


class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    gpu_available: bool
    device: str


class TrainRequest(BaseModel):
    dataset_path: str
    num_epochs: Optional[int] = 3
    batch_size: Optional[int] = 4
    learning_rate: Optional[float] = 2e-5


class TrainResponse(BaseModel):
    status: str
    message: str
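
# Illustrative request payloads for the models above (example values only, not
# taken from the original project):
#
#   GenerateRequest: {"text": "Summarize the contraindications of ibuprofen.",
#                     "max_length": 256, "temperature": 0.7, "num_return_sequences": 1}
#   TrainRequest:    {"dataset_path": "/app/data/medical.json",
#                     "num_epochs": 3, "batch_size": 4, "learning_rate": 2e-5}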

# Training status tracking
class TrainingStatus:
    def __init__(self):
        self.is_training = False
        self.current_epoch = 0
        self.current_loss = None
        self.status = "idle"


training_status = TrainingStatus()

# Initialize FastAPI app
app = FastAPI(
    title="Medical LLaMA API",
    description="API for medical text generation using fine-tuned LLaMA model",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global variables for model and tokenizer
model = None
tokenizer = None

@app.get("/", response_model=HealthResponse)
async def root():
    """
    Root endpoint to check API health and model status.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return HealthResponse(
        status="online",
        model_loaded=model is not None,
        gpu_available=torch.cuda.is_available(),
        device=device,
    )

@app.post("/generate", response_model=GenerateResponse)
async def generate_text(request: GenerateRequest):
    """
    Generate medical text based on an input prompt.
    """
    try:
        # Check if model is loaded
        if model is None or tokenizer is None:
            logger.error("Model or tokenizer not initialized")
            raise HTTPException(
                status_code=500,
                detail="Model not loaded. Please check if the model was initialized correctly.",
            )

        logger.info(f"Generating text for input: {request.text[:50]}...")

        # Log device information
        logger.info(f"Using device: {model.device}")

        # Tokenize input
        try:
            inputs = tokenizer(
                request.text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=request.max_length,
            )
            logger.info("Input tokenized successfully")
            # Move inputs to the model's device
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
        except Exception as e:
            logger.error(f"Tokenization error: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Tokenization failed: {str(e)}")

        # Generate text
        try:
            with torch.no_grad():
                generated_ids = model.generate(
                    **inputs,  # inputs is a plain dict (input_ids, attention_mask) after the device move
                    max_length=request.max_length,
                    num_return_sequences=request.num_return_sequences,
                    temperature=request.temperature,
                    do_sample=True,  # sampling must be enabled for temperature to take effect
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )
            logger.info("Text generated successfully")
        except Exception as e:
            logger.error(f"Generation error: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Text generation failed: {str(e)}")

        # Decode generated text
        try:
            generated_texts = [
                tokenizer.decode(g, skip_special_tokens=True)
                for g in generated_ids
            ]
            logger.info("Text decoded successfully")
        except Exception as e:
            logger.error(f"Decoding error: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Text decoding failed: {str(e)}")

        return GenerateResponse(generated_text=generated_texts)

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred: {str(e)}",
        )
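
# Example call to the generation endpoint (illustrative; assumes the /generate route
# used above and a server listening on localhost:7860):
#
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Patient presents with chest pain and dyspnea.", "max_length": 256}'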

@app.get("/health")
async def health_check():
    """
    Check the health status of the API and model.
    """
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "gpu_available": torch.cuda.is_available(),
        "device": "cuda" if torch.cuda.is_available() else "cpu",
    }

@app.on_event("startup")
async def startup_event():
    logger.info("Starting up application...")
    try:
        global tokenizer, model
        tokenizer, model = init_model()
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")

@app.post("/train", response_model=TrainResponse)
async def train_model(request: TrainRequest, background_tasks: BackgroundTasks):
    """
    Start model training with the specified dataset.

    Parameters:
    - dataset_path: Path to the JSON dataset file
    - num_epochs: Number of training epochs
    - batch_size: Training batch size
    - learning_rate: Learning rate for training
    """
    if training_status.is_training:
        raise HTTPException(status_code=400, detail="Training is already in progress")

    try:
        # Verify dataset exists
        if not os.path.exists(request.dataset_path):
            raise HTTPException(status_code=404, detail="Dataset file not found")

        # Start training in background
        background_tasks.add_task(
            run_training,
            request.dataset_path,
            request.num_epochs,
            request.batch_size,
            request.learning_rate,
        )

        return TrainResponse(
            status="started",
            message="Training started in background",
        )
    except HTTPException:
        # Re-raise HTTP errors (e.g. the 404 above) instead of converting them to 500s
        raise
    except Exception as e:
        logger.error(f"Training setup error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/training-status")
async def get_training_status():
    """
    Get the current training status.
    """
    return {
        "is_training": training_status.is_training,
        "current_epoch": training_status.current_epoch,
        "current_loss": training_status.current_loss,
        "status": training_status.status,
    }
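
# Expected dataset layout for run_training below (an assumption inferred from the
# load_dataset("json", ...) call and the "text" field used in preprocessing): a JSON
# file readable by the datasets "json" loader in which every record has a "text"
# field, e.g. one object per line:
#
#   {"text": "Chief complaint: ..."}
#   {"text": "Discharge summary: ..."}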

# Training function. Defined as a plain (sync) function so FastAPI's BackgroundTasks
# runs it in a worker thread and trainer.train() does not block the event loop.
def run_training(dataset_path: str, num_epochs: int, batch_size: int, learning_rate: float):
    global model, tokenizer, training_status
    try:
        training_status.is_training = True
        training_status.status = "loading_dataset"

        # Load dataset
        dataset = load_dataset("json", data_files=dataset_path)

        training_status.status = "preprocessing"

        # Preprocess function
        def preprocess_function(examples):
            return tokenizer(
                examples["text"],
                truncation=True,
                padding="max_length",
                max_length=512,
            )

        # Tokenize dataset
        tokenized_dataset = dataset.map(
            preprocess_function,
            batched=True,
            remove_columns=dataset["train"].column_names,
        )

        training_status.status = "training"

        # Training arguments
        training_args = TrainingArguments(
            output_dir=f"{model_output_path}/checkpoints",
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=4,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            fp16=torch.cuda.is_available(),  # half precision only when a GPU is present
            save_steps=500,
            logging_steps=100,
        )

        # Initialize trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset["train"],
            data_collator=DataCollatorForLanguageModeling(
                tokenizer=tokenizer,
                mlm=False,
            ),
        )

        # Callback to surface training progress through the status endpoint
        class TrainingCallback(TrainerCallback):
            def on_epoch_begin(self, args, state, control, **kwargs):
                training_status.current_epoch = state.epoch

            def on_log(self, args, state, control, logs=None, **kwargs):
                if logs:
                    training_status.current_loss = logs.get("loss", None)

        trainer.add_callback(TrainingCallback())

        # Start training
        trainer.train()

        # Save the model
        training_status.status = "saving"
        model.save_pretrained(model_output_path)
        tokenizer.save_pretrained(model_output_path)

        training_status.status = "completed"
        logger.info("Training completed successfully")
    except Exception as e:
        training_status.status = f"failed: {str(e)}"
        logger.error(f"Training error: {str(e)}")
        raise
    finally:
        training_status.is_training = False

# Model initialization
def init_model():
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Loading model on device: {device}")

        model_name = "nvidia/Meta-Llama-3.2-3B-Instruct-ONNX-INT4"

        # Load tokenizer
        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            cache_dir="/app/cache",
            trust_remote_code=True,
        )

        # Add padding token if not present
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        logger.info("Loading model...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map="auto",
            cache_dir="/app/cache",
            trust_remote_code=True,
        )

        logger.info(f"Model loaded successfully on {device}")
        return tokenizer, model
    except Exception as e:
        logger.error(f"Model initialization error: {str(e)}")
        raise

@app.get("/model-status")
async def model_status():
    """
    Get detailed model status.
    """
    try:
        model_info = {
            "model_loaded": model is not None,
            "tokenizer_loaded": tokenizer is not None,
            "model_device": str(model.device) if model else None,
            "gpu_available": torch.cuda.is_available(),
            "cuda_device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
            "cuda_device_name": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None,
            "model_type": type(model).__name__ if model else None,
            "tokenizer_type": type(tokenizer).__name__ if tokenizer else None,
        }

        if model is not None:
            test_input = None
            try:
                # Test tokenizer
                test_input = tokenizer("test", return_tensors="pt")
                model_info["tokenizer_working"] = True
            except Exception as e:
                model_info["tokenizer_working"] = False
                model_info["tokenizer_error"] = str(e)

            # Only attempt a forward pass if tokenization produced an input
            if test_input is not None:
                try:
                    # Test model generation
                    with torch.no_grad():
                        model.generate(
                            test_input.input_ids.to(model.device),
                            max_length=10,
                        )
                    model_info["model_working"] = True
                except Exception as e:
                    model_info["model_working"] = False
                    model_info["model_error"] = str(e)

        return model_info
    except Exception as e:
        logger.error(f"Error checking model status: {str(e)}")
        return {
            "error": str(e),
            "model_loaded": model is not None,
            "tokenizer_loaded": tokenizer is not None,
        }
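
# Local entry point (a minimal sketch, not part of the original file). Host and port
# are assumptions: 7860 is the port a Docker-based Hugging Face Space listens on by
# default; adjust for other deployments.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)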