import os
import tempfile
from pathlib import Path

import torch
import tiktoken
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from huggingface_hub import hf_hub_download
from pydantic import BaseModel

from model import GPT, GPTConfig

# Get the absolute path to the templates directory
TEMPLATES_DIR = os.path.join(os.path.dirname(__file__), "templates")
MODEL_ID = "sagargurujula/text-generator"
# Initialize FastAPI
app = FastAPI(title="GPT Text Generator")
# Templates with absolute path
templates = Jinja2Templates(directory=TEMPLATES_DIR)
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Use system's temporary directory
cache_dir = Path(tempfile.gettempdir()) / "model_cache"
os.environ['TRANSFORMERS_CACHE'] = str(cache_dir)
os.environ['HF_HOME'] = str(cache_dir)
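# Note: recent versions of transformers prefer HF_HOME; TRANSFORMERS_CACHE is
# set as well for backward compatibility with older releases.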
# Load model from Hugging Face Hub
def load_model():
    try:
        # Download the model weights from the HF Hub (authenticated if HF_TOKEN is set)
        model_path = hf_hub_download(
            repo_id=MODEL_ID,
            filename="best_model.pth",
            cache_dir=cache_dir,
            token=os.environ.get('HF_TOKEN')  # Read token from environment variable
        )
        # Initialize our custom GPT model
        model = GPT(GPTConfig())
        # Load the checkpoint and restore the trained weights
        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        raise

# Load the model
model = load_model()
# Define the request body
class TextInput(BaseModel):
    text: str
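# Example request body (a hypothetical prompt, shown for illustration):
#   {"text": "Once upon a time"}
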
@app.post("/generate/")
async def generate_text(input: TextInput):
    # Encode the prompt with the GPT-2 BPE tokenizer
    enc = tiktoken.get_encoding('gpt2')
    input_ids = torch.tensor([enc.encode(input.text)]).to(device)

    # Generate new tokens autoregressively (greedy decoding)
    generated_tokens = []
    num_tokens_to_generate = 50  # Number of new tokens to generate

    with torch.no_grad():
        current_ids = input_ids
        for _ in range(num_tokens_to_generate):
            # Get model predictions for the current sequence
            logits, _ = model(current_ids)
            # Pick the most likely next token (greedy argmax)
            next_token = logits[0, -1, :].argmax().item()
            generated_tokens.append(next_token)
            # Append the new token to the running sequence
            current_ids = torch.cat([current_ids, torch.tensor([[next_token]]).to(device)], dim=1)

    # Decode only the newly generated tokens back to text
    generated_text = enc.decode(generated_tokens)

    # Return both the prompt and the generated continuation
    return {
        "input_text": input.text,
        "generated_text": generated_text
    }
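# Example usage (a sketch, assuming the server is running locally on port 8080):
#   curl -X POST http://127.0.0.1:8080/generate/ \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Once upon a time"}'
#
# Expected response shape:
#   {"input_text": "Once upon a time", "generated_text": "..."}
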
# Modify the root route to serve the template
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    return templates.TemplateResponse(
        "index.html",
        {"request": request, "title": "GPT Text Generator"}
    )
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="127.0.0.1", port=8080)

# To run the app, use the command: uvicorn app:app --reload
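# A minimal Python client sketch (assumes the server is up on 127.0.0.1:8080 and
# the third-party `requests` package is installed):
#
#   import requests
#
#   resp = requests.post(
#       "http://127.0.0.1:8080/generate/",
#       json={"text": "Once upon a time"},
#   )
#   resp.raise_for_status()
#   print(resp.json()["generated_text"])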