import torch
import os
from fastapi import FastAPI, Request
from pydantic import BaseModel
from huggingface_hub import hf_hub_download
import tiktoken
from model import GPT, GPTConfig
from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from pathlib import Path
import tempfile
# Get the absolute path to the templates directory
TEMPLATES_DIR = os.path.join(os.path.dirname(__file__), "templates")

MODEL_ID = "sagargurujula/text-generator"

# Initialize FastAPI
app = FastAPI(title="GPT Text Generator")

# Templates with absolute path
templates = Jinja2Templates(directory=TEMPLATES_DIR)
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Use the system's temporary directory as the model cache
cache_dir = Path(tempfile.gettempdir()) / "model_cache"
os.environ['TRANSFORMERS_CACHE'] = str(cache_dir)
os.environ['HF_HOME'] = str(cache_dir)
# Load model from Hugging Face Hub
def load_model():
    try:
        # Download the model file from HF Hub with authentication
        model_path = hf_hub_download(
            repo_id=MODEL_ID,
            filename="best_model.pth",
            cache_dir=cache_dir,
            token=os.environ.get('HF_TOKEN')  # Get token from environment variable
        )

        # Initialize our custom GPT model
        model = GPT(GPTConfig())

        # Load the state dict
        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()

        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        raise

# Load the model
model = load_model()


# Define the request body
class TextInput(BaseModel):
    text: str

@app.post("/generate")  # endpoint path "/generate" is assumed; the route decorator is missing above
async def generate_text(input: TextInput):
    # Prepare the input tensor using the GPT-2 BPE tokenizer
    enc = tiktoken.get_encoding('gpt2')
    input_ids = torch.tensor([enc.encode(input.text)]).to(device)

    # Generate multiple tokens autoregressively
    generated_tokens = []
    num_tokens_to_generate = 50  # Number of new tokens to generate

    with torch.no_grad():
        current_ids = input_ids
        for _ in range(num_tokens_to_generate):
            # Get model predictions for the current sequence
            logits, _ = model(current_ids)
            # Greedy decoding: pick the most likely next token
            next_token = logits[0, -1, :].argmax().item()
            generated_tokens.append(next_token)
            # Add the new token to our current sequence
            current_ids = torch.cat([current_ids, torch.tensor([[next_token]]).to(device)], dim=1)

    # Decode all generated tokens
    generated_text = enc.decode(generated_tokens)

    # Return both input and generated text
    return {
        "input_text": input.text,
        "generated_text": generated_text
    }

# Root route serves the HTML template
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    return templates.TemplateResponse(
        "index.html",
        {"request": request, "title": "GPT Text Generator"}
    )

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="127.0.0.1", port=8080)

# To run the app, use the command: uvicorn app:app --reload
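# Example client call (a minimal sketch, assuming the server is running locally on
# port 8080 and that the POST route is mounted at "/generate" as above; uses the
# third-party `requests` package):
#
#   import requests
#   resp = requests.post(
#       "http://127.0.0.1:8080/generate",
#       json={"text": "Once upon a time"},
#   )
#   print(resp.json())  # {"input_text": "Once upon a time", "generated_text": "..."}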