import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import gradio as gr

# Load the pre-trained model and tokenizer
model_name = "PleIAs/OCRonos-Vintage"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# GPT-2 tokenizers have no pad token by default; reuse the EOS token so the
# padding/truncation requested in the tokenizer call below does not error out.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Set the device to GPU if available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


def historical_generation(prompt, max_new_tokens=600):
    # OCRonos-Vintage expects the "### Text ###" header in front of the prompt
    prompt = f"### Text ###\n{prompt}"
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    # Generate text
    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        top_k=50,
        temperature=0.3,
        top_p=0.95,
        do_sample=True,
        repetition_penalty=1.5,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id
    )

    # Decode the generated text
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

    # Remove the prompt header from the generated text
    generated_text = generated_text.replace("### Text ###\n", "").strip()

    # Re-tokenize the generated text so each token can be highlighted
    tokens = tokenizer.tokenize(generated_text)

    # Build the (token, label) pairs expected by gr.HighlightedText
    highlighted_text = []
    for token in tokens:
        # Strip the GPT-2 byte-level markers ("Ġ" marks a leading space, "Ċ" a newline)
        clean_token = token.replace("Ġ", "").replace("Ċ", "")
        # Round-trip the token through the vocabulary and use the raw token string as the label
        token_type = tokenizer.convert_ids_to_tokens([tokenizer.convert_tokens_to_ids(token)])[0]
        highlighted_text.append((clean_token, token_type))

    return highlighted_text


# Create the Gradio interface
iface = gr.Interface(
    fn=historical_generation,
    inputs=[
        gr.Textbox(
            label="Prompt",
            placeholder="Enter a prompt for historical text generation...",
            lines=3
        ),
        gr.Slider(
            label="Max New Tokens",
            minimum=50,
            maximum=1000,
            step=50,
            value=600
        )
    ],
    outputs=gr.HighlightedText(
        label="Generated Historical Text",
        combine_adjacent=True,
        show_legend=True
    ),
    title="Historical Text Generation with OCRonos-Vintage",
    description="Generate historical-style text using the OCRonos-Vintage model. The output shows token types as highlights.",
    theme=gr.themes.Base()
)

if __name__ == "__main__":
    iface.launch()
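
# --- Optional smoke test (sketch) -----------------------------------------
# A minimal way to sanity-check the pipeline without the web UI, assuming the
# model weights downloaded correctly. The prompt below is only an illustrative
# placeholder, not from the original demo. Uncomment and run before (or
# instead of) launching Gradio:
#
#   pairs = historical_generation("The morning papers reported", max_new_tokens=100)
#   print(pairs[:20])  # first few (token, label) tuples fed to HighlightedText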