# Space status when this file was captured: Runtime error
# app.py
import gradio as gr
import torch
from transformers import PreTrainedTokenizerFast

from model import GPTModel  # Import your specific GPT model class
# Load model and tokenizer once at startup
def load_model_n_tokenizer(
    model_id="Aananda-giri/GPT2-Nepali",
    tokenizer_id="Aananda-giri/NepaliBPE",
    device=None,
):
    """Load the Nepali GPT model and its tokenizer, ready for inference.

    Args:
        model_id: Repository ID passed to ``GPTModel.from_pretrained``.
        tokenizer_id: Repository ID passed to
            ``PreTrainedTokenizerFast.from_pretrained``.
        device: Device to move the model to. When ``None`` (the default),
            CUDA is used if available, otherwise CPU.

    Returns:
        A ``(model, tokenizer)`` tuple. The model has been moved to
        ``device`` and switched to evaluation mode.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GPTModel.from_pretrained(model_id)
    model.to(device)
    # Put the model in eval mode here, so every caller of the loader gets an
    # inference-ready model rather than relying on each call site to remember.
    # (Assumes GPTModel is a torch nn.Module, where eval() disables dropout.)
    model.eval()
    tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_id)
    return model, tokenizer


# Initialize at startup; the loader has already switched the model to eval mode.
model, tokenizer = load_model_n_tokenizer()
def generate(prompt, max_new_tokens, top_k, temperature, repetition_penalty, penalize_len_below):
    """Generate text continuing *prompt* with the module-level model and tokenizer.

    Args:
        prompt: Input text to condition generation on.
        max_new_tokens: Maximum number of new tokens; coerced to ``int``.
        top_k: Top-k value forwarded to ``model.generate``; coerced to ``int``.
        temperature: Temperature forwarded to ``model.generate``; coerced to
            ``float``.
        repetition_penalty: Repetition penalty forwarded to ``model.generate``;
            coerced to ``float``.
        penalize_len_below: Forwarded to ``model.generate`` as ``min_length``;
            coerced to ``int``.

    Returns:
        The first output sequence decoded with special tokens skipped. Returns
        ``""`` when the prompt is empty/``None`` or encodes to zero tokens.
    """
    if not prompt:
        return ""

    # Depending on the Gradio version, slider values can arrive as floats.
    # Token counts and top-k must be integers (e.g. torch.topk rejects a float k).
    max_new_tokens = int(max_new_tokens)
    top_k = int(top_k)
    min_length = int(penalize_len_below)
    temperature = float(temperature)
    repetition_penalty = float(repetition_penalty)

    device = next(model.parameters()).device
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    if input_ids.shape[-1] == 0:
        # No tokens to condition on; don't call generate with an empty sequence.
        return ""

    with torch.no_grad():
        # NOTE(review): if GPTModel.generate follows the transformers API,
        # top_k/temperature only take effect with do_sample=True, and
        # min_length counts the prompt tokens too. Confirm against model.py.
        outputs = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            top_k=top_k,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            min_length=min_length,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Gradio UI: a prompt box followed by one slider per generation knob. Gradio
# passes the input values positionally, in this order, to generate().
interface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(placeholder="Enter Nepali text here...", label="Prompt"),
        gr.Slider(label="Max New Tokens", minimum=1, maximum=512, step=1, value=50),
        gr.Slider(label="Top K", minimum=1, maximum=100, step=1, value=3),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7),
        gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.1, value=1.2),
        gr.Slider(label="Minimum Length Penalty", minimum=1, maximum=200, step=1, value=50),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="Nepali GPT-2 Text Generator",
    description="Enter Nepali text to generate content using the custom GPT-2 model.",
)

# Start serving the app; this runs whenever the module is executed.
interface.launch()