# Hugging Face Spaces app (Space status at capture time: "Sleeping").
import torch | |
from transformers import GPT2LMHeadModel, GPT2Tokenizer | |
import gradio as gr | |
# Fetch the OCRonos-Vintage checkpoint (GPT-2 architecture) together with its
# matching GPT-2 tokenizer; both come from the same Hub repository.
model_name = "PleIAs/OCRonos-Vintage"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# Prefer CUDA when a GPU is visible to torch, otherwise run on the CPU, and
# move the model weights there once at startup.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
def historical_generation(prompt, max_new_tokens=600):
    """Continue ``prompt`` with OCRonos-Vintage and return it as highlighted tokens.

    Args:
        prompt: Free text to continue. It is prefixed with a ``### Text ###``
            header line before being fed to the model.
        max_new_tokens: Upper bound on the number of tokens to sample. It is
            cast to ``int`` (in case the UI delivers a float) and may be
            reduced so that prompt plus continuation fit in the model's
            position-embedding window.

    Returns:
        A list of ``(text, label)`` pairs suitable for ``gr.HighlightedText``.
        ``text`` is the readable text of one token of the prompt-plus-
        continuation (header marker removed), and ``label`` is that token's
        raw tokenizer string.
    """
    generated_text = _generate_text(prompt, max_new_tokens)
    return _token_highlights(generated_text)


def _generate_text(prompt, max_new_tokens):
    """Sample a continuation of ``prompt``; return prompt + continuation without the header."""
    max_new_tokens = int(max_new_tokens)
    prompt = f"### Text ###\n{prompt}"

    # A single sequence needs no padding. Requesting padding would also raise
    # a ValueError when the tokenizer defines no pad token.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    # GPT-2 style models have a fixed number of position embeddings
    # (config.n_positions). Prompt + continuation must fit, otherwise
    # generation fails on long prompts. Shrink the continuation first, and
    # drop the oldest prompt tokens only when the prompt alone fills the window.
    context_limit = getattr(model.config, "n_positions", None)
    if context_limit is not None:
        room = context_limit - input_ids.shape[-1]
        max_new_tokens = max(1, min(max_new_tokens, room))
        keep = context_limit - max_new_tokens
        input_ids = input_ids[:, -keep:]
        attention_mask = attention_mask[:, -keep:]

    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        top_k=50,
        temperature=0.3,
        top_p=0.95,
        do_sample=True,
        repetition_penalty=1.5,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    text = tokenizer.decode(output[0], skip_special_tokens=True)
    # The decoded sequence still contains the user's prompt. Only the header
    # marker is stripped, so the prompt stays visible before the continuation.
    return text.replace("### Text ###\n", "").strip()


def _token_highlights(text):
    """Split ``text`` into ``(readable_text, raw_token)`` pairs for HighlightedText."""
    # convert_tokens_to_string undoes the byte-level BPE encoding, so markers
    # such as U+0120 (leading space) and U+010A (newline) come back as the
    # whitespace they encode. Otherwise they are shown literally, or deleted
    # and glue adjacent words together.
    return [
        (tokenizer.convert_tokens_to_string([token]), token)
        for token in tokenizer.tokenize(text)
    ]
# Gradio front end: a prompt textbox and a length slider (default 600, the
# same as historical_generation's default) feed historical_generation, whose
# (text, label) pairs are rendered as highlighted text with a legend.
_prompt_input = gr.Textbox(
    label="Prompt",
    placeholder="Enter a prompt for historical text generation...",
    lines=3,
)
_length_input = gr.Slider(
    label="Max New Tokens",
    minimum=50,
    maximum=1000,
    step=50,
    value=600,
)
_generated_output = gr.HighlightedText(
    label="Generated Historical Text",
    combine_adjacent=True,
    show_legend=True,
)

iface = gr.Interface(
    fn=historical_generation,
    inputs=[_prompt_input, _length_input],
    outputs=_generated_output,
    title="Historical Text Generation with OCRonos-Vintage",
    description=(
        "Generate historical-style text using the OCRonos-Vintage model. "
        "The output shows token types as highlights."
    ),
    theme=gr.themes.Base(),
)
if __name__ == "__main__":
    # Serve the demo only when this file is executed directly; importing the
    # module just builds `iface` without starting a server.
    iface.launch()