# app.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import torch
import uvicorn
from transformers import pipeline
import os
from contextlib import asynccontextmanager  # Needed for the lifespan startup/shutdown handler
import sys  # Used to exit the process if model loading fails
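# Runtime dependencies implied by the imports above (a rough note; pin versions as needed):
#   fastapi, uvicorn, pydantic, torch, transformers
#   (huggingface_hub is only required if you enable the gated-model login below)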

# Optional: For gated models like Llama 3 from Meta, uncomment and configure HF_TOKEN
# from huggingface_hub import login

# --- Global variable to store the pipeline ---
generator = None
# Choose a model appropriate for the free tier (roughly 7B-8B parameters or smaller)
# For DeepSeek, DeepSeek-V2-Lite-Base might be loadable, but DeepSeek-V3 is far too large.
MODEL_NAME = "brendon-ai/gemma3-dolly-finetuned"
# Lighter alternative for the free tier: "openai-community/gpt2"

# --- Lifespan Event Handler ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Handles startup and shutdown events for the FastAPI application.
    Loads the model on startup and can optionally clean up on shutdown.
    """
    global generator
    try:
        # --- Optional: Login to Hugging Face Hub for gated models ---
        # If you are using a gated model (e.g., meta-llama/Llama-3-8B-Instruct),
        # uncomment the following lines and ensure HF_TOKEN is set as a Space Secret.
        # hf_token = os.getenv("HF_TOKEN")
        # if hf_token:
        #     login(token=hf_token)
        #     print("Logged into Hugging Face Hub.")
        # else:
        #     print("HF_TOKEN not found. Make sure it's set as a Space Secret if using a gated model.")

        # --- Startup Code: Load the model ---
        if torch.cuda.is_available():
            print(f"CUDA is available! Using {torch.cuda.get_device_name(0)}")
            device = 0 # Use GPU
            # For larger models, use device_map="auto" and torch_dtype
            # device_map = "auto"
            # torch_dtype = torch.bfloat16 # or torch.float16 for GPUs that support it
        else:
            print("CUDA not available, using CPU. Inference will be very slow for this model size.")
            device = -1 # Use CPU
            # device_map = None
            # torch_dtype = torch.float32 # Default for CPU

        print(f"Attempting to load model '{MODEL_NAME}' on device: {'cuda' if device == 0 else 'cpu'}")

        # The pipeline automatically handles AutoModel and AutoTokenizer.
        # For better memory management with larger models, directly load with model_kwargs:
        generator = pipeline(
            'text-generation',
            model=MODEL_NAME,
            device=device,
            # Pass your HF token to the model loading for gated models
            # token=os.getenv("HF_TOKEN"), # Uncomment if using a gated model
            # For 7B models on 16GB GPU, float16 is usually enough, but bfloat16 is better if supported
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            # For more fine-grained control and auto device mapping for multiple GPUs:
            # model_kwargs={"device_map": "auto", "torch_dtype": torch.float16}
        )
        print("Model loaded successfully!")
        
        # 'yield' signifies that the startup code has completed, and the application
        # can now start processing requests.
        yield 
        
    except Exception as e:
        print(f"CRITICAL ERROR: Failed to load model during startup: {e}")
        # Exit with a non-zero code to indicate failure if model loading fails
        sys.exit(1)
            
    finally:
        # --- Shutdown Code (Optional): Clean up resources ---
        print("Application shutting down. Any cleanup can go here.")


# --- Initialize FastAPI application with the lifespan handler ---
app = FastAPI(
    lifespan=lifespan,  # Use the lifespan context manager
    title="Text Generation API",
    description="A simple text generation API using Hugging Face transformers",
    version="1.0.0"
)

# Request model
class TextGenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: Optional[int] = 250  # Number of new tokens to generate (more precise than max_length)
    num_return_sequences: Optional[int] = 1
    temperature: Optional[float] = 0.7  # Lower values give more coherent, less random output
    do_sample: Optional[bool] = True
    top_p: Optional[float] = 0.9  # Nucleus sampling cutoff

# Response model
class TextGenerationResponse(BaseModel):
    generated_text: str
    prompt: str
    model_name: str

@app.get("/")
async def root():
    return {
        "message": "Text Generation API", 
        "status": "running",
        "endpoints": {
            "generate_post": "/generate", # Renamed for clarity
            "generate_get": "/generate_simple", # Renamed for clarity
            "health": "/health",
            "docs": "/docs"
        },
        "current_model": MODEL_NAME
    }

@app.get("/health")
async def health_check():
    return {
        "status": "healthy" if generator else "unhealthy",
        "model_loaded": generator is not None,
        "cuda_available": torch.cuda.is_available(),
        "model_name": MODEL_NAME
    }

@app.post("/generate", response_model=TextGenerationResponse)
async def generate_text_post(request: TextGenerationRequest):
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet. Service unavailable.")
            
    try:
        # Generate text
        result = generator(
            request.prompt,
            max_new_tokens=request.max_new_tokens,
            num_return_sequences=request.num_return_sequences,
            temperature=request.temperature,
            do_sample=request.do_sample,
            top_p=request.top_p,
            pad_token_id=generator.tokenizer.eos_token_id,
            eos_token_id=generator.tokenizer.eos_token_id,
            # Add stop sequences relevant to your instruction-tuned model format
            # stop_sequences=["\nUser:", "\n###", "\n\n"] 
        )
        
        generated_text = result[0]['generated_text']
        
        return TextGenerationResponse(
            generated_text=generated_text,
            prompt=request.prompt,
            model_name=MODEL_NAME
        )
            
    except Exception as e:
        print(f"Generation failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}. Check Space logs for details.")

@app.get("/generate_simple") # Changed endpoint name to avoid conflict with POST
async def generate_text_get(
    prompt: str,
    max_new_tokens: int = 250,
    temperature: float = 0.7
):
    """GET endpoint for simple text generation"""
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet. Service unavailable.")
            
    try:
        result = generator(
            prompt,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            temperature=temperature,
            do_sample=True,
            top_p=0.9, # Default top_p for simple GET
            pad_token_id=generator.tokenizer.eos_token_id,
            eos_token_id=generator.tokenizer.eos_token_id,
        )
        
        return {
            "generated_text": result[0]['generated_text'],
            "prompt": prompt,
            "model_name": MODEL_NAME
        }
            
    except Exception as e:
        print(f"Generation failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}. Check Space logs for details.")

if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860)) # Hugging Face Spaces uses port 7860
    uvicorn.run(app, host="0.0.0.0", port=port)
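
# --- Example client usage (illustrative sketch, not executed by the app) ---
# A minimal way to exercise the POST /generate endpoint once the server is running
# locally on the default port above; uses only the standard library.
#
# import json, urllib.request
#
# payload = json.dumps({"prompt": "Once upon a time", "max_new_tokens": 50}).encode()
# req = urllib.request.Request(
#     "http://localhost:7860/generate",
#     data=payload,
#     headers={"Content-Type": "application/json"},
# )
# with urllib.request.urlopen(req) as resp:
#     print(json.load(resp)["generated_text"])
#
# The simple GET variant can be hit directly, e.g.:
#   http://localhost:7860/generate_simple?prompt=Hello&max_new_tokens=50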