# app.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import torch
import uvicorn
from transformers import pipeline
import os
from contextlib import asynccontextmanager  # Needed for the lifespan startup/shutdown handler
import sys  # Needed for sys.exit() on fatal startup errors
# Optional: For gated models like Llama 3 from Meta, uncomment and configure HF_TOKEN
# from huggingface_hub import login
# --- Global variable to store the pipeline ---
generator = None
# Choose a model appropriate for the free tier (roughly 7B-8B parameters or smaller).
# For DeepSeek, DeepSeek-V2-Lite-Base might be loadable, but DeepSeek-V3 is far too big.
MODEL_NAME = "brendon-ai/gemma3-dolly-finetuned"
# MODEL_NAME = "openai-community/gpt2"  # Smaller fallback recommended for the free tier
# --- Lifespan Event Handler ---
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
Handles startup and shutdown events for the FastAPI application.
Loads the model on startup and can optionally clean up on shutdown.
"""
global generator
try:
# --- Optional: Login to Hugging Face Hub for gated models ---
# If you are using a gated model (e.g., meta-llama/Llama-3-8B-Instruct),
# uncomment the following lines and ensure HF_TOKEN is set as a Space Secret.
# hf_token = os.getenv("HF_TOKEN")
# if hf_token:
# login(token=hf_token)
# print("Logged into Hugging Face Hub.")
# else:
# print("HF_TOKEN not found. Make sure it's set as a Space Secret if using a gated model.")
# --- Startup Code: Load the model ---
if torch.cuda.is_available():
print(f"CUDA is available! Using {torch.cuda.get_device_name(0)}")
device = 0 # Use GPU
# For larger models, use device_map="auto" and torch_dtype
# device_map = "auto"
# torch_dtype = torch.bfloat16 # or torch.float16 for GPUs that support it
else:
print("CUDA not available, using CPU. Inference will be very slow for this model size.")
device = -1 # Use CPU
# device_map = None
# torch_dtype = torch.float32 # Default for CPU
print(f"Attempting to load model '{MODEL_NAME}' on device: {'cuda' if device == 0 else 'cpu'}")
# The pipeline automatically handles AutoModel and AutoTokenizer.
# For better memory management with larger models, directly load with model_kwargs:
generator = pipeline(
'text-generation',
model=MODEL_NAME,
device=device,
# Pass your HF token to the model loading for gated models
# token=os.getenv("HF_TOKEN"), # Uncomment if using a gated model
# For 7B models on 16GB GPU, float16 is usually enough, but bfloat16 is better if supported
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
# For more fine-grained control and auto device mapping for multiple GPUs:
# model_kwargs={"device_map": "auto", "torch_dtype": torch.float16}
)
print("Model loaded successfully!")
# 'yield' signifies that the startup code has completed, and the application
# can now start processing requests.
yield
except Exception as e:
print(f"CRITICAL ERROR: Failed to load model during startup: {e}")
# Exit with a non-zero code to indicate failure if model loading fails
sys.exit(1)
finally:
# --- Shutdown Code (Optional): Clean up resources ---
print("Application shutting down. Any cleanup can go here.")
# --- Initialize FastAPI application with the lifespan handler ---
app = FastAPI(
    lifespan=lifespan,  # Use the lifespan context manager
    title="Text Generation API",
    description="A simple text generation API using Hugging Face transformers",
    version="1.0.0",
)
# Request model
class TextGenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: Optional[int] = 250  # max_new_tokens gives better control than max_length
    num_return_sequences: Optional[int] = 1
    temperature: Optional[float] = 0.7  # Lower temperatures give more coherent output
    do_sample: Optional[bool] = True
    top_p: Optional[float] = 0.9  # Nucleus sampling for additional control
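
# Illustrative request body for POST /generate (the prompt text is just an example;
# every field other than "prompt" is optional and falls back to the defaults above):
#
#   {
#     "prompt": "Write a short poem about the sea.",
#     "max_new_tokens": 100,
#     "temperature": 0.7,
#     "top_p": 0.9
#   }
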
# Response model
class TextGenerationResponse(BaseModel):
    generated_text: str
    prompt: str
    model_name: str
@app.get("/")
async def root():
    return {
        "message": "Text Generation API",
        "status": "running",
        "endpoints": {
            "generate_post": "/generate",
            "generate_get": "/generate_simple",
            "health": "/health",
            "docs": "/docs"
        },
        "current_model": MODEL_NAME
    }
@app.get("/health")
async def health_check():
    return {
        "status": "healthy" if generator else "unhealthy",
        "model_loaded": generator is not None,
        "cuda_available": torch.cuda.is_available(),
        "model_name": MODEL_NAME
    }
@app.post("/generate", response_model=TextGenerationResponse)
async def generate_text_post(request: TextGenerationRequest):
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet. Service unavailable.")
    try:
        # Generate text with the loaded pipeline
        result = generator(
            request.prompt,
            max_new_tokens=request.max_new_tokens,
            num_return_sequences=request.num_return_sequences,
            temperature=request.temperature,
            do_sample=request.do_sample,
            top_p=request.top_p,
            pad_token_id=generator.tokenizer.eos_token_id,
            eos_token_id=generator.tokenizer.eos_token_id,
            # Add stop sequences relevant to your instruction-tuned model's prompt format, e.g.:
            # stop_sequences=["\nUser:", "\n###", "\n\n"]
        )
        generated_text = result[0]['generated_text']
        return TextGenerationResponse(
            generated_text=generated_text,
            prompt=request.prompt,
            model_name=MODEL_NAME
        )
    except Exception as e:
        print(f"Generation failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}. Check the Space logs for details.")
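
# Example call against the POST endpoint above (the host is a placeholder; substitute
# your Space's actual URL):
#
#   curl -X POST "https://<your-space>.hf.space/generate" \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello, world!", "max_new_tokens": 50}'
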
@app.get("/generate_simple") # Changed endpoint name to avoid conflict with POST
async def generate_text_get(
    prompt: str,
    max_new_tokens: int = 250,
    temperature: float = 0.7
):
    """GET endpoint for simple text generation."""
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet. Service unavailable.")
    try:
        result = generator(
            prompt,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            temperature=temperature,
            do_sample=True,
            top_p=0.9,  # Default top_p for the simple GET endpoint
            pad_token_id=generator.tokenizer.eos_token_id,
            eos_token_id=generator.tokenizer.eos_token_id,
        )
        return {
            "generated_text": result[0]['generated_text'],
            "prompt": prompt,
            "model_name": MODEL_NAME
        }
    except Exception as e:
        print(f"Generation failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}. Check the Space logs for details.")
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))  # Hugging Face Spaces expects the app on port 7860
    uvicorn.run(app, host="0.0.0.0", port=port)
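
# Example usage once the server is running (locally via `python app.py`, or on the Space
# with its URL swapped in for localhost):
#
#   curl "http://localhost:7860/generate_simple?prompt=Hello&max_new_tokens=50&temperature=0.7"
#
# FastAPI also serves interactive API documentation at /docs.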