from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import traceback

# ---------------------------
# Setup
# ---------------------------
app = FastAPI(title="GPT-Neo Story API (HF Free-tier Optimized)")

model_name = "EleutherAI/gpt-neo-125M"

# Limit threads for HF Free-tier CPU stability
torch.set_num_threads(1)

# ---------------------------
# Request Model
# ---------------------------
class StoryRequest(BaseModel):
    prompt: str

# ---------------------------
# Tokenizer & Model
# ---------------------------
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    # GPT-Neo ships without a pad token; reuse EOS so padding works
    tokenizer.pad_token = tokenizer.eos_token

print("Loading model...")
try:
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        dtype=torch.float32,      # `dtype` replaces the deprecated `torch_dtype` argument
        device_map="cpu",
        low_cpu_mem_usage=True,
    )
    model.eval()
    print("GPT-Neo-125M loaded successfully on CPU!")
except Exception as e:
    traceback.print_exc()
    raise RuntimeError(f"Failed to load model: {str(e)}")

# ---------------------------
# Streaming story generation (HF-safe)
# ---------------------------
def stream_generate(prompt_text: str, max_new_tokens: int = 80, chunk_size: int = 10):
    """
    Generate the story in small chunks so partial text can be streamed
    to the client from the HF Free-tier CPU.
    """
    try:
        prompt = f"""You are a creative storyteller. Write a short story (one paragraph) that is surprising and imaginative, with a twist.
Prompt: {prompt_text}"""

        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512,   # cap prompt length to keep generation fast
            padding=True,
        )
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        generated_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]

        with torch.inference_mode():
            for _ in range(0, max_new_tokens, chunk_size):
                outputs = model.generate(
                    input_ids=generated_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=chunk_size,
                    do_sample=True,
                    temperature=0.8,          # moderate randomness
                    top_p=0.9,
                    top_k=40,
                    repetition_penalty=1.1,
                    pad_token_id=tokenizer.pad_token_id,
                    use_cache=False,          # trade speed for lower memory
                )

                # Yield only the newly generated tokens, then keep the full
                # sequence as context for the next chunk
                new_tokens = outputs[:, generated_ids.shape[1]:]
                generated_ids = outputs

                # Extend the attention mask to cover the new tokens
                attention_mask = torch.cat(
                    [attention_mask, torch.ones_like(new_tokens)], dim=1
                )

                text_chunk = tokenizer.decode(new_tokens[0], skip_special_tokens=True)
                yield text_chunk

        # Release tensor references once generation is done
        del inputs, outputs
    except Exception as e:
        traceback.print_exc()
        yield f"\nError generating story: {str(e)}"

# ---------------------------
# Endpoint
# ---------------------------
@app.post("/api/generate-story")
async def generate_story(req: StoryRequest):
    prompt_text = req.prompt.strip()
    if not prompt_text:
        raise HTTPException(status_code=400, detail="Prompt must not be empty")
    return StreamingResponse(stream_generate(prompt_text), media_type="text/plain")

# ---------------------------
# Run server
# ---------------------------
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
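
# ---------------------------
# Example client (sketch)
# ---------------------------
# A minimal sketch of how a client could consume the streaming endpoint.
# It assumes the server above is running locally on port 7860 and that the
# `requests` package is available (it is not a dependency of this file);
# the prompt text is only an illustration.
#
#   import requests
#
#   with requests.post(
#       "http://localhost:7860/api/generate-story",
#       json={"prompt": "A lighthouse keeper finds a door at the bottom of the sea"},
#       stream=True,
#       timeout=120,
#   ) as resp:
#       resp.raise_for_status()
#       # chunk_size=None yields data as it arrives, matching the chunked generator
#       for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#           print(chunk, end="", flush=True)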