|
import os

import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
|
|
|
|
# Directory holding the fine-tuned GPT-2 weights and tokenizer files.
# Set the GPT2_MODEL_PATH environment variable to point elsewhere; the default
# keeps the original (machine-specific) location so existing setups still work.
model_name_or_path = os.environ.get(
    "GPT2_MODEL_PATH", "C:/Users/faiza/Downloads/fine_tuned_model"
)

# Load the model and its matching tokenizer from the same directory.
model = GPT2LMHeadModel.from_pretrained(model_name_or_path)
tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# This script only does inference; make that explicit (disables dropout).
model.eval()
|
|
|
|
|
def generate_text(seed_text, max_length=100, temperature=1.0, num_return_sequences=1):
    """Sample continuations of ``seed_text`` from the module-level GPT-2 model.

    Args:
        seed_text: Prompt to continue. If it tokenizes to nothing (e.g. ""),
            generation starts from the tokenizer's BOS token (falling back to
            EOS) instead of an empty sequence, which ``generate`` cannot use.
        max_length: Value passed as ``max_length`` to ``model.generate``.
        temperature: Sampling temperature passed to ``model.generate``.
        num_return_sequences: Number of sampled sequences to return.

    Returns:
        A list of decoded strings, one per returned sequence, with special
        tokens removed.
    """
    # Let the tokenizer build both tensors so the attention mask has the
    # dtype/shape it expects, rather than hand-building a float mask.
    encoded = tokenizer(seed_text, return_tensors="pt")
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]

    if input_ids.shape[-1] == 0:
        # Empty prompt: seed generation with a single start token.
        start_id = tokenizer.bos_token_id
        if start_id is None:
            start_id = tokenizer.eos_token_id
        input_ids = torch.tensor([[start_id]], dtype=torch.long)
        attention_mask = torch.ones_like(input_ids)

    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    with torch.no_grad():
        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            temperature=temperature,
            num_return_sequences=num_return_sequences,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            # GPT-2 has no pad token; reuse EOS so generate() can pad batches.
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode every returned row at once instead of indexing in a loop.
    return tokenizer.batch_decode(output, skip_special_tokens=True)
|
|
|
|
|
def predict(seed_text, max_length, temperature, num_return_sequences):
    """Gradio callback: generate samples and join them for a single textbox.

    Slider values may arrive from the UI as floats, depending on the Gradio
    version. ``max_length`` and ``num_return_sequences`` are used as counts
    downstream (e.g. with ``range()``), so they are coerced to ``int`` here,
    at the UI boundary, and ``temperature`` to ``float``.

    Returns:
        The generated samples as one string, separated by blank lines.
    """
    generated_texts = generate_text(
        seed_text,
        int(round(max_length)),
        float(temperature),
        int(round(num_return_sequences)),
    )
    return "\n\n".join(generated_texts)
|
|
|
|
|
def _build_interface():
    """Create the Gradio UI that feeds the four generation controls into ``predict``."""
    # Components are created in the same order they appear in the form.
    seed_box = gr.Textbox(lines=2, placeholder="Enter seed text here...", label="Seed Text")
    length_slider = gr.Slider(minimum=50, maximum=500, value=50, step=1, label="Max Length")
    temperature_slider = gr.Slider(minimum=0.1, maximum=1.5, value=1.0, step=0.1, label="Temperature")
    count_slider = gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Number of Return Sequences")
    controls = [seed_box, length_slider, temperature_slider, count_slider]

    return gr.Interface(
        fn=predict,
        inputs=controls,
        outputs=gr.Textbox(),
        title="GPT-2 Text Generation",
        description="Enter some text and see the generated output based on the fine-tuned GPT-2 model.",
    )


interface = _build_interface()

if __name__ == "__main__":
    # Start the local web server only when run as a script, not on import.
    interface.launch()
|
|