Spaces:
Sleeping
Sleeping
File size: 5,719 Bytes
4929bc6 fa951d7 65bec20 fa951d7 4929bc6 fa951d7 8c1d821 4929bc6 d153be5 3582cae 65bec20 be6c757 4929bc6 8c1d821 a103a7f 4929bc6 a103a7f be6c757 a103a7f be6c757 a103a7f 4929bc6 2ca0200 4929bc6 65bec20 4929bc6 2f70cad be6c757 4929bc6 2ca0200 2f70cad 2ca0200 3988e91 2f70cad be6c757 65bec20 2f70cad 65bec20 3582cae 3988e91 2f70cad be6c757 2f70cad 3988e91 2f70cad 3988e91 2f70cad 3988e91 3582cae 2f70cad 3988e91 3582cae 2f70cad 3582cae 3988e91 be6c757 2f70cad 3582cae 3988e91 3582cae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import gradio as gr
import os
import spacy
from spacy import displacy
# Generation model: OCRonos-Vintage, loaded through the GPT-2 classes.
model_name = "PleIAs/OCRonos-Vintage"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# GPT-2 tokenizers have no pad token; reuse EOS so `padding=True` works.
tokenizer.pad_token = tokenizer.eos_token
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# spaCy pipeline used for the dependency-parse visualisations.
# Only download it when it is not installed yet. spacy.cli.download installs
# with the interpreter running this app; a bare "python" on PATH may belong to
# a different environment.
spacy_model_name = "en_core_web_sm"
try:
    nlp = spacy.load(spacy_model_name)
except OSError:
    spacy.cli.download(spacy_model_name)
    nlp = spacy.load(spacy_model_name)
def historical_generation(prompt, max_new_tokens=600, top_k=50, temperature=0.7, top_p=0.95, repetition_penalty=1.0):
    """Generate historical-style text with OCRonos-Vintage.

    The prompt is prefixed with the model's ``### Text ###`` section header and
    a sampled continuation is generated.

    Args:
        prompt: User text to continue.
        max_new_tokens: Maximum number of tokens to generate. Cast to ``int``
            and reduced if prompt + continuation would exceed the model's
            context size.
        top_k: Top-k sampling cut-off. Cast to ``int`` because Gradio sliders
            may deliver floats and top-k sampling requires an integer.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Repetition penalty passed to ``generate``.

    Returns:
        ``(highlighted_text, generated_text)``:
        - ``highlighted_text`` is a list of ``(token, label)`` pairs for
          ``gr.HighlightedText``. Each BPE token (with the "Ġ" space marker
          removed) is used as its own label.
        - ``generated_text`` is the decoded output. If it contains
          ``### Correction ###``, only the section right after the first
          marker is returned (stripped); otherwise the leading
          ``### Text ###`` header is removed.
    """
    header = "### Text ###\n"
    correction_marker = "### Correction ###"

    # Gradio sliders can hand back floats; token counts and top_k must be ints.
    max_new_tokens = int(max_new_tokens)
    top_k = int(top_k)

    # GPT-2 style models have a fixed number of position embeddings, so the
    # prompt and the continuation together must fit inside it.
    context_size = getattr(model.config, "n_positions", 1024)
    inputs = tokenizer(
        f"{header}{prompt}",
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=min(1024, context_size - 1),  # leave room for >= 1 new token
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    max_new_tokens = max(1, min(max_new_tokens, context_size - input_ids.shape[1]))

    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        top_k=top_k,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    if correction_marker in generated_text:
        # Keep only the section that follows the first marker.
        generated_text = generated_text.split(correction_marker)[1].strip()
    elif generated_text.startswith(header):
        # The decoded sequence echoes the prompt; hide the section header that
        # was only added for the model.
        generated_text = generated_text[len(header):]

    # Byte-level BPE marks a leading space with "Ġ"; drop it for display.
    # Each token doubles as its own highlight label.
    highlighted_text = []
    for token in tokenizer.tokenize(generated_text):
        clean_token = token.replace("Ġ", "")
        highlighted_text.append((clean_token, clean_token))
    return highlighted_text, generated_text
def text_analysis(text):
    """Parse *text* with spaCy and build the values shown in the UI.

    Returns:
        A 3-tuple ``(pos_tokens, pos_count, html)``:
        - ``pos_tokens``: list of ``(token_text, token.pos_)`` pairs.
        - ``pos_count``: dict with ``"char_count"`` (characters in *text*) and
          ``"token_count"`` (tokens in the spaCy doc).
        - ``html``: the displaCy dependency-parse page wrapped in a scrollable
          ``<div>``.
    """
    parsed = nlp(text)

    dep_page = displacy.render(parsed, style="dep", page=True)
    opening = "<div style='max-width:100%; max-height:360px; overflow:auto'>"
    closing = "</div>"
    wrapped_html = opening + dep_page + closing

    counts = {"char_count": len(text), "token_count": len(parsed)}
    tagged = [(tok.text, tok.pos_) for tok in parsed]
    return tagged, counts, wrapped_html
def generate_dependency_parse(generated_text):
    """Return the dependency-parse HTML for *generated_text*.

    Thin wrapper around ``text_analysis`` that discards the token list and the
    count dict and keeps only the rendered HTML.
    """
    _, _, html_generated = text_analysis(generated_text)
    return html_generated
def full_interface(prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty):
    """Gradio callback for the "Generate" button.

    Generates text from *prompt* and analyses the prompt itself with spaCy.

    Returns:
        An 8-tuple matching the click handler's ``outputs`` list, in order:
        1. the generated text;
        2. the token highlights;
        3. the input-text counts;
        4. the input dependency-parse HTML;
        5. an update showing the "visualize generated text" button;
        6. an empty string clearing the generated-text parse panel;
        7. an update hiding the generate button;
        8. an update showing the reset button.
    """
    generated_highlight, generated_text = historical_generation(
        prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty
    )
    _, pos_count_input, html_input = text_analysis(prompt)
    return (
        generated_text,
        generated_highlight,
        pos_count_input,
        html_input,
        gr.update(visible=True),
        # Clear any parse from a previous run. The raw model output is not
        # HTML and must not be rendered into the HTML component.
        "",
        gr.update(visible=False),
        gr.update(visible=True),
    )
def reset_interface():
    """Restore the button layout shown before any generation.

    Returns three component updates, in order: the generate button (shown),
    the "visualize generated text" button (hidden) and the reset button
    (hidden).
    """
    generate_shown = gr.update(visible=True)
    visualize_hidden = gr.update(visible=False)
    reset_hidden = gr.update(visible=False)
    return generate_shown, visualize_hidden, reset_hidden
import gradio as gr
with gr.Blocks(theme=gr.themes.Base()) as iface:
    gr.Markdown("""
# Historical Text Generator with Dependency Parse
This app generates historical-style text using the OCRonos-Vintage model.
You can customize the generation parameters using the sliders and visualize the tokenized output and dependency parse.
""")
    prompt = gr.Textbox(label="Add a passage in the style of historical texts", placeholder="Hi there my name is Tonic and I ride my bicycle along the river Seine:", lines=3)

    # Sliders for model parameters
    max_new_tokens = gr.Slider(label="Max New Tokens", minimum=50, maximum=1000, step=10, value=140)
    # top_k is an integer count of candidate tokens, so step in whole numbers.
    top_k = gr.Slider(label="Top-k Sampling", minimum=1, maximum=100, step=1, value=50)
    temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, step=0.05, value=0.3)
    top_p = gr.Slider(label="Top-p (Nucleus Sampling)", minimum=0.1, maximum=1.0, step=0.005, value=0.95)
    repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=0.5, maximum=2.0, step=0.05, value=1.0)

    # Output components. Textbox has no `readonly` argument;
    # `interactive=False` makes it display-only.
    generated_text_output = gr.Textbox(label="🎅🏻⌚OCRonos-Vintage", interactive=False)
    highlighted_text = gr.HighlightedText(label="🎅🏻⌚Tokenized", combine_adjacent=True, show_legend=True)
    tokenizer_info = gr.JSON(label="📉Tokenizer Info (Input Text)")
    dependency_parse_input = gr.HTML(label="👁️Visualization")

    # Hidden button and final output for dependency parse visualization
    send_button = gr.Button(value="👁️Visualize Generated Text", visible=False)
    dependency_parse_generated = gr.HTML(label="👁️Visualization (Generated Text)")

    # Reset button, hidden initially
    reset_button = gr.Button(value="♻️Start Again", visible=False)

    # Main interface logic: when clicked, "Generate" button hides itself and shows the reset button
    generate_button = gr.Button(value="🎅🏻⌚Generate Historical Text")
    generate_button.click(
        full_interface,
        inputs=[prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty],
        outputs=[generated_text_output, highlighted_text, tokenizer_info, dependency_parse_input, send_button, dependency_parse_generated, generate_button, reset_button]
    )

    # Parse the generated text on demand when the visualize button is pressed.
    send_button.click(
        generate_dependency_parse,
        inputs=[generated_text_output],
        outputs=[dependency_parse_generated]
    )

    # Reset button logic: hide itself and re-show the "Generate" button
    reset_button.click(
        reset_interface,
        inputs=None,
        outputs=[generate_button, send_button, reset_button]
    )

iface.launch()
|