import os
import json
import logging
from typing import List, Optional

import torch
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    TrainerCallback,
    DataCollatorForLanguageModeling,
)
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Setup cache directory
os.makedirs("/app/cache", exist_ok=True)
os.environ['TRANSFORMERS_CACHE'] = "/app/cache"

# Pydantic models for request/response
class GenerateRequest(BaseModel):
    text: str
    max_length: Optional[int] = 512
    temperature: Optional[float] = 0.7
    num_return_sequences: Optional[int] = 1

class GenerateResponse(BaseModel):
    generated_text: List[str]

class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    gpu_available: bool
    device: str

class TrainRequest(BaseModel):
    dataset_path: str
    num_epochs: Optional[int] = 3
    batch_size: Optional[int] = 4
    learning_rate: Optional[float] = 2e-5

class TrainResponse(BaseModel):
    status: str
    message: str

# Add training status tracking
class TrainingStatus:
    def __init__(self):
        self.is_training = False
        self.current_epoch = 0
        self.current_loss = None
        self.status = "idle"

training_status = TrainingStatus()
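# Note: training_status lives in process memory, so it is only meaningful
# when the app runs as a single worker process (typical for a Space); with
# multiple workers each process would hold its own independent copy.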

# Initialize FastAPI app
app = FastAPI(
    title="Medical LLaMA API",
    description="API for medical text generation using fine-tuned LLaMA model",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global variables for model and tokenizer
model = None
tokenizer = None

# Where fine-tuned weights are saved (assumed container path; adjust as needed)
model_output_path = "/app/model_output"

@app.get("/", response_model=HealthResponse)
async def root():
    """
    Root endpoint to check API health and model status
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return HealthResponse(
        status="online",
        model_loaded=model is not None,
        gpu_available=torch.cuda.is_available(),
        device=device
    )

@app.post("/generate", response_model=GenerateResponse)
async def generate_text(request: GenerateRequest):
    """
    Generate medical text based on input prompt
    """
    try:
        # Check if model is loaded
        if model is None or tokenizer is None:
            logger.error("Model or tokenizer not initialized")
            raise HTTPException(
                status_code=500,
                detail="Model not loaded. Please check if model was initialized correctly."
            )
        logger.info(f"Generating text for input: {request.text[:50]}...")

        # Log device information
        logger.info(f"Using device: {model.device}")

        # Tokenize input
        try:
            inputs = tokenizer(
                request.text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=request.max_length
            )
            logger.info("Input tokenized successfully")
            # Move inputs to correct device
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
        except Exception as e:
            logger.error(f"Tokenization error: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Tokenization failed: {str(e)}")
        # Generate text
        try:
            with torch.no_grad():
                generated_ids = model.generate(
                    inputs["input_ids"],  # inputs is a plain dict after the .to() move above
                    attention_mask=inputs["attention_mask"],
                    max_length=request.max_length,
                    num_return_sequences=request.num_return_sequences,
                    temperature=request.temperature,
                    do_sample=True,  # temperature has no effect under greedy decoding
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )
            logger.info("Text generated successfully")
        except Exception as e:
            logger.error(f"Generation error: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Text generation failed: {str(e)}")
        # Decode generated text
        try:
            generated_texts = [
                tokenizer.decode(g, skip_special_tokens=True)
                for g in generated_ids
            ]
            logger.info("Text decoded successfully")
        except Exception as e:
            logger.error(f"Decoding error: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Text decoding failed: {str(e)}")

        return GenerateResponse(generated_text=generated_texts)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred: {str(e)}"
        )
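
# Example call against the /generate route above (illustrative values; the
# host/port assume a local uvicorn server, see the entrypoint at the bottom):
#
#   curl -X POST http://localhost:7860/generate \
#     -H "Content-Type: application/json" \
#     -d '{"text": "Patient presents with chest pain and", "max_length": 256}'
#
# The response body matches GenerateResponse, e.g. {"generated_text": ["..."]}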

@app.get("/health")
async def health_check():
    """
    Check the health status of the API and model
    """
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "gpu_available": torch.cuda.is_available(),
        "device": "cuda" if torch.cuda.is_available() else "cpu"
    }

@app.on_event("startup")
async def startup_event():
    logger.info("Starting up application...")
    try:
        global tokenizer, model
        tokenizer, model = init_model()
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")

@app.post("/train", response_model=TrainResponse)
async def train_model(request: TrainRequest, background_tasks: BackgroundTasks):
    """
    Start model training with the specified dataset

    Parameters:
    - dataset_path: Path to the JSON dataset file
    - num_epochs: Number of training epochs
    - batch_size: Training batch size
    - learning_rate: Learning rate for training
    """
    if training_status.is_training:
        raise HTTPException(status_code=400, detail="Training is already in progress")
    try:
        # Verify dataset exists
        if not os.path.exists(request.dataset_path):
            raise HTTPException(status_code=404, detail="Dataset file not found")
        # Start training in background
        background_tasks.add_task(
            run_training,
            request.dataset_path,
            request.num_epochs,
            request.batch_size,
            request.learning_rate
        )
        return TrainResponse(
            status="started",
            message="Training started in background"
        )
    except HTTPException:
        # Re-raise as-is so the 404 above is not swallowed into a 500
        raise
    except Exception as e:
        logger.error(f"Training setup error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/train/status")
async def get_training_status():
    """
    Get current training status
    """
    return {
        "is_training": training_status.is_training,
        "current_epoch": training_status.current_epoch,
        "current_loss": training_status.current_loss,
        "status": training_status.status
    }
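
# Expected /train dataset format (inferred from preprocess_function below,
# which reads a "text" field): a JSON or JSON Lines file that
# datasets.load_dataset("json", ...) can parse, e.g. one object per line:
#
#   {"text": "Question: What are common symptoms of anemia? Answer: ..."}
#   {"text": "Question: How is hypertension managed? Answer: ..."}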

# Add training function
def run_training(dataset_path: str, num_epochs: int, batch_size: int, learning_rate: float):
    # Deliberately a plain (non-async) function: FastAPI's BackgroundTasks
    # runs sync callables in a thread pool, so the blocking Trainer loop
    # does not stall the event loop.
    global model, tokenizer, training_status
    try:
        training_status.is_training = True
        training_status.status = "loading_dataset"
        # Load dataset
        dataset = load_dataset("json", data_files=dataset_path)
        training_status.status = "preprocessing"

        # Preprocess function
        def preprocess_function(examples):
            return tokenizer(
                examples["text"],
                truncation=True,
                padding="max_length",
                max_length=512
            )

        # Tokenize dataset
        tokenized_dataset = dataset.map(
            preprocess_function,
            batched=True,
            remove_columns=dataset["train"].column_names
        )
        training_status.status = "training"
        # Training arguments
        training_args = TrainingArguments(
            output_dir=f"{model_output_path}/checkpoints",
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=4,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            fp16=torch.cuda.is_available(),  # mixed precision only makes sense on GPU
            save_steps=500,
            logging_steps=100,
        )
        # Initialize trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset["train"],
            data_collator=DataCollatorForLanguageModeling(
                tokenizer=tokenizer,
                mlm=False
            ),
        )
        # Training callback to update status
        class StatusCallback(TrainerCallback):
            def on_epoch_begin(self, args, state, control, **kwargs):
                training_status.current_epoch = int(state.epoch or 0)

            def on_log(self, args, state, control, logs=None, **kwargs):
                if logs:
                    training_status.current_loss = logs.get("loss")

        trainer.add_callback(StatusCallback())
        # Start training
        trainer.train()

        # Save the model
        training_status.status = "saving"
        model.save_pretrained(model_output_path)
        tokenizer.save_pretrained(model_output_path)
        training_status.status = "completed"
        logger.info("Training completed successfully")
    except Exception as e:
        training_status.status = f"failed: {str(e)}"
        logger.error(f"Training error: {str(e)}")
        raise
    finally:
        training_status.is_training = False

# Update model initialization
def init_model():
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Loading model on device: {device}")
        # AutoModelForCausalLM needs a PyTorch checkpoint; an ONNX export such as
        # "nvidia/Meta-Llama-3.2-3B-Instruct-ONNX-INT4" cannot be loaded this way.
        model_name = "meta-llama/Llama-3.2-3B-Instruct"
        # Load tokenizer
        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            cache_dir="/app/cache",
            trust_remote_code=True
        )
        # Add padding token if not present
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        logger.info("Loading model...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map="auto",
            cache_dir="/app/cache",
            trust_remote_code=True
        )
        logger.info(f"Model loaded successfully on {device}")
        return tokenizer, model
    except Exception as e:
        logger.error(f"Model initialization error: {str(e)}")
        raise

@app.get("/model/status")
async def model_status():
    """
    Get detailed model status
    """
    try:
        model_info = {
            "model_loaded": model is not None,
            "tokenizer_loaded": tokenizer is not None,
            "model_device": str(model.device) if model else None,
            "gpu_available": torch.cuda.is_available(),
            "cuda_device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
            "cuda_device_name": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None,
            "model_type": type(model).__name__ if model else None,
            "tokenizer_type": type(tokenizer).__name__ if tokenizer else None,
        }
        if model is not None and tokenizer is not None:
            test_input = None
            try:
                # Test tokenizer
                test_input = tokenizer("test", return_tensors="pt")
                model_info["tokenizer_working"] = True
            except Exception as e:
                model_info["tokenizer_working"] = False
                model_info["tokenizer_error"] = str(e)
            if test_input is not None:
                try:
                    # Test model forward pass
                    with torch.no_grad():
                        model.generate(
                            test_input.input_ids.to(model.device),
                            max_length=10
                        )
                    model_info["model_working"] = True
                except Exception as e:
                    model_info["model_working"] = False
                    model_info["model_error"] = str(e)
        return model_info
    except Exception as e:
        logger.error(f"Error checking model status: {str(e)}")
        return {
            "error": str(e),
            "model_loaded": model is not None,
            "tokenizer_loaded": tokenizer is not None
        }
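
# Entrypoint sketch for running the API standalone. Assumes uvicorn is
# installed; port 7860 is the Hugging Face Spaces convention.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)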