import os
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Setup cache directory
os.makedirs("/app/cache", exist_ok=True)
os.environ['TRANSFORMERS_CACHE'] = "/app/cache"  # deprecated in recent transformers releases; HF_HOME is the newer variable
app = FastAPI(title="Medical LLaMA API")
# Add CORS middleware (wide open for development; restrict allow_origins in
# production, since browsers ignore credentials with a wildcard origin)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Check GPU availability
def check_gpu():
    if torch.cuda.is_available():
        logger.info(f"GPU available: {torch.cuda.get_device_name(0)}")
        return True
    logger.warning("No GPU available, using CPU")
    return False
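
# Hedged sketch (not in the original file): a minimal health probe so a
# container orchestrator can check the service; reports whether a GPU was
# detected. The route name is an illustrative assumption.
@app.get("/health")
async def health():
    return {"status": "ok", "gpu": torch.cuda.is_available()}
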
# Initialize model with proper device
def init_model():
    try:
        device = "cuda" if check_gpu() else "cpu"
        model_path = os.getenv("MODEL_PATH", "./model/medical_llama_3b")
        logger.info(f"Loading model from {model_path}")
        tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir="/app/cache")
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,  # half precision only on GPU
            device_map="auto",  # requires the accelerate package; places weights across available devices
            cache_dir="/app/cache",
        )
        return tokenizer, model
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        raise
# Rest of your existing code...
@app.on_event("startup")
async def startup_event():
    logger.info("Starting up application...")
    try:
        global tokenizer, model
        tokenizer, model = init_model()
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")
        raise  # fail fast: without a model, every request would hit an undefined global