File size: 5,719 Bytes
4929bc6
 
fa951d7
65bec20
 
 
fa951d7
4929bc6
 
 
fa951d7
8c1d821
 
4929bc6
 
 
d153be5
3582cae
65bec20
be6c757
4929bc6
8c1d821
a103a7f
 
4929bc6
a103a7f
 
 
 
 
be6c757
 
 
a103a7f
be6c757
a103a7f
 
 
4929bc6
 
2ca0200
 
 
4929bc6
 
65bec20
4929bc6
 
2f70cad
be6c757
4929bc6
2ca0200
2f70cad
2ca0200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3988e91
 
 
 
2f70cad
 
 
 
be6c757
 
 
 
65bec20
 
2f70cad
65bec20
3582cae
 
3988e91
2f70cad
be6c757
2f70cad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3988e91
 
2f70cad
 
 
 
3988e91
 
2f70cad
 
3988e91
3582cae
2f70cad
3988e91
3582cae
2f70cad
3582cae
3988e91
be6c757
2f70cad
3582cae
 
 
 
 
 
 
3988e91
 
3582cae
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import gradio as gr
import os
import spacy
from spacy import displacy

# Load the OCRonos-Vintage GPT-2 checkpoint and its tokenizer.
model_name = "PleIAs/OCRonos-Vintage"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# GPT-2 ships without a pad token; reuse EOS so padding in tokenizer() works.
tokenizer.pad_token = tokenizer.eos_token

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Download the small English spaCy pipeline at startup, then load it.
# NOTE(review): this shells out on every launch; consider trying spacy.load
# first and only downloading on OSError — confirm deployment constraints.
os.system('python -m spacy download en_core_web_sm')
nlp = spacy.load("en_core_web_sm")

def historical_generation(prompt, max_new_tokens=600, top_k=50, temperature=0.7, top_p=0.95, repetition_penalty=1.0):
    """Generate historical-style text from *prompt* with OCRonos-Vintage.

    The sampling parameters are passed straight through to ``model.generate``.

    Returns:
        tuple: ``(highlighted_tokens, generated_text)`` where
        ``highlighted_tokens`` is a list of ``(token, label)`` pairs suitable
        for a ``gr.HighlightedText`` component, and ``generated_text`` is the
        decoded model output (the part after "### Correction ###" if present).
    """
    # The model was trained with a "### Text ###" section header.
    prompt = f"### Text ###\n{prompt}"
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        # Cast to int: Gradio sliders can hand back floats, which
        # model.generate() rejects for these integer-valued arguments.
        max_new_tokens=int(max_new_tokens),
        pad_token_id=tokenizer.eos_token_id,
        top_k=int(top_k),
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id
    )

    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

    # Keep only the model's correction section when it emits one.
    if "### Correction ###" in generated_text:
        generated_text = generated_text.split("### Correction ###")[1].strip()

    # Bug fix: the original computed a "token type" by converting each token
    # to its id and back — an identity round-trip, so the label was always
    # equal to the cleaned token. Build the same (token, label) pairs
    # directly, without the pointless conversions. "Ġ" is GPT-2's
    # leading-space marker and is stripped for display.
    highlighted_text = [
        (token.replace("Ġ", ""), token.replace("Ġ", ""))
        for token in tokenizer.tokenize(generated_text)
    ]

    return highlighted_text, generated_text

def text_analysis(text):
    """Analyze *text* with spaCy.

    Returns:
        tuple: ``(pos_tokens, pos_count, html)`` — a list of
        ``(token_text, POS_tag)`` pairs, a dict with ``char_count`` and
        ``token_count``, and the displaCy dependency-parse HTML wrapped in a
        scrollable container.
    """
    doc = nlp(text)
    # Wrap the displaCy render so long parses scroll instead of overflowing.
    rendered = displacy.render(doc, style="dep", page=True)
    html = f"<div style='max-width:100%; max-height:360px; overflow:auto'>{rendered}</div>"
    counts = {
        "char_count": len(text),
        "token_count": len(list(doc)),
    }
    tagged = [(tok.text, tok.pos_) for tok in doc]
    return tagged, counts, html

def generate_dependency_parse(generated_text):
    """Return the displaCy dependency-parse HTML for *generated_text*.

    Bug fix: this function was defined twice, byte-for-byte identical; the
    redundant second definition has been removed.
    """
    _tokens, _pos_count, html_generated = text_analysis(generated_text)
    return html_generated

def full_interface(prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty):
    """Run one generation round and produce every value the UI needs.

    Returns the generated text, its token highlighting, tokenizer stats and
    dependency parse for the *input* prompt, plus ``gr.update`` objects that
    reveal the visualize/reset buttons and hide the generate button.
    """
    highlight, text_out = historical_generation(
        prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty
    )
    _tokens, input_stats, input_html = text_analysis(prompt)
    return (
        text_out,
        highlight,
        input_stats,
        input_html,
        gr.update(visible=True),   # show the "Visualize Generated Text" button
        text_out,                  # seed the generated-text parse placeholder
        gr.update(visible=False),  # hide the Generate button
        gr.update(visible=True),   # show the Reset button
    )

def reset_interface():
    """Restore the initial button state: show Generate, hide Visualize/Reset."""
    show = gr.update(visible=True)
    hide = gr.update(visible=False)
    return show, hide, hide

# NOTE(review): a duplicate "import gradio as gr" was removed here — gradio
# is already imported at the top of the file.

with gr.Blocks(theme=gr.themes.Base()) as iface:

    gr.Markdown("""
    # Historical Text Generator with Dependency Parse
    This app generates historical-style text using the OCRonos-Vintage model. 
    You can customize the generation parameters using the sliders and visualize the tokenized output and dependency parse.
    """)

    prompt = gr.Textbox(label="Add a passage in the style of historical texts", placeholder="Hi there my name is Tonic and I ride my bicycle along the river Seine:", lines=3)

    # Sliders for model parameters
    max_new_tokens = gr.Slider(label="Max New Tokens", minimum=50, maximum=1000, step=10, value=140)
    # Bug fix: top-k counts candidate tokens, so it must be an integer; the
    # original step of 0.05 produced floats that model.generate() rejects.
    top_k = gr.Slider(label="Top-k Sampling", minimum=1, maximum=100, step=1, value=50)
    temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, step=0.05, value=0.3)
    top_p = gr.Slider(label="Top-p (Nucleus Sampling)", minimum=0.1, maximum=1.0, step=0.005, value=0.95)
    repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=0.5, maximum=2.0, step=0.05, value=1.0)

    # Output components
    # Bug fix: gr.Textbox has no "readonly" argument; interactive=False is the
    # supported way to make the box display-only.
    generated_text_output = gr.Textbox(label="🎅🏻⌚OCRonos-Vintage", interactive=False)
    highlighted_text = gr.HighlightedText(label="🎅🏻⌚Tokenized", combine_adjacent=True, show_legend=True)
    tokenizer_info = gr.JSON(label="📉Tokenizer Info (Input Text)")
    dependency_parse_input = gr.HTML(label="👁️Visualization")

    # Hidden button and final output for dependency parse visualization
    send_button = gr.Button(value="👁️Visualize Generated Text", visible=False)
    # Bug fix: the original label had unbalanced quotes
    # (label="👁️Visualization" (Generated Text)") — a SyntaxError.
    dependency_parse_generated = gr.HTML(label="👁️Visualization (Generated Text)")

    # Reset button, hidden initially
    reset_button = gr.Button(value="♻️Start Again", visible=False)

    # Main interface logic: when clicked, "Generate" button hides itself and shows the reset button
    generate_button = gr.Button(value="🎅🏻⌚Generate Historical Text")
    generate_button.click(
        full_interface,
        inputs=[prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty],
        outputs=[generated_text_output, highlighted_text, tokenizer_info, dependency_parse_input, send_button, dependency_parse_generated, generate_button, reset_button]
    )

    # Bug fix: send_button was revealed after generation but never wired to a
    # handler, so clicking it did nothing. Hook it to the parse renderer.
    send_button.click(
        generate_dependency_parse,
        inputs=[generated_text_output],
        outputs=[dependency_parse_generated]
    )

    # Reset button logic: hide itself and re-show the "Generate" button
    reset_button.click(
        reset_interface,
        inputs=None,
        outputs=[generate_button, send_button, reset_button]
    )

iface.launch()