The Space runs a small FastAPI app that loads the GPT-2 base model, attaches the LoRA adapter from the Hub, and exposes a POST endpoint for text generation:

```python
import os

# Point the Hugging Face caches at a writable directory inside the container.
# Set these before transformers resolves its cache paths.
os.environ["HF_HOME"] = "/app/cache"
os.environ["TRANSFORMERS_CACHE"] = "/app/cache"

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

app = FastAPI()

# Load the frozen GPT-2 base model, then attach the LoRA adapter on top of it
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("n-sudheer/ns-lora-gpt2-demo")
model = PeftModel.from_pretrained(base_model, "n-sudheer/ns-lora-gpt2-demo")
model.eval()

class GenerationRequest(BaseModel):
    prompt: str
    max_length: int = 50

@app.post("/generate")  # route path assumed for this endpoint
def generate_text(request: GenerationRequest):
    # Hard-coded response for prompts that mention '@NS'
    if "@NS" in request.prompt:
        return {"generated_text": "NS is God!!!"}
    inputs = tokenizer(request.prompt, return_tensors="pt")
    # do_sample and top_k are generate() arguments, not decode() arguments
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=request.max_length,
            do_sample=True,
            top_k=20,
        )
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("Input Prompt:", request.prompt)
    print("Generated Text:", text)
    return {"generated_text": text}
```