Tonic committed
Commit 4929bc6 · verified · 1 Parent(s): fa951d7

Update app.py

Files changed (1)
  app.py +66 -26
app.py CHANGED
@@ -1,36 +1,76 @@
-from difflib import Differ
-
+import torch
+from transformers import GPT2LMHeadModel, GPT2Tokenizer
 import gradio as gr
 
-def diff_texts(text1, text2):
-    d = Differ()
-    return [
-        (token[2:], token[0] if token[0] != " " else None)
-        for token in d.compare(text1, text2)
-    ]
+# Load pre-trained model and tokenizer
+model_name = "PleIAs/OCRonos-Vintage"
+model = GPT2LMHeadModel.from_pretrained(model_name)
+tokenizer = GPT2Tokenizer.from_pretrained(model_name)
 
-demo = gr.Interface(
-    diff_texts,
-    [
-        gr.Textbox(
-            label="Text 1",
-            info="Initial text",
-            lines=3,
-            value="The quick brown fox jumped over the lazy dogs.",
-        ),
+# Set the device to GPU if available, otherwise use CPU
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model.to(device)
+
+def historical_generation(prompt, max_new_tokens=600):
+    prompt = f"### Text ###\n{prompt}"
+    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
+
+    # Generate text
+    output = model.generate(input_ids,
+                            max_new_tokens=max_new_tokens,
+                            pad_token_id=tokenizer.eos_token_id,
+                            top_k=50,
+                            temperature=0.3,
+                            top_p=0.95,
+                            do_sample=True,
+                            repetition_penalty=1.5)
+
+    # Decode the generated text
+    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+
+    # Remove the prompt from the generated text
+    generated_text = generated_text.replace("### Text ###\n", "").strip()
+
+    # Tokenize the generated text
+    tokens = tokenizer.tokenize(generated_text)
+
+    # Create highlighted text output
+    highlighted_text = []
+    for token in tokens:
+        # Remove special tokens and get the token type
+        clean_token = token.replace("Ġ", "").replace("</w>", "")
+        token_type = tokenizer.convert_ids_to_tokens([tokenizer.convert_tokens_to_ids(token)])[0]
+
+        highlighted_text.append((clean_token, token_type))
+
+    return highlighted_text
+
+# Create Gradio interface
+iface = gr.Interface(
+    fn=historical_generation,
+    inputs=[
         gr.Textbox(
-            label="Text 2",
-            info="Text to compare",
-            lines=3,
-            value="The fast brown fox jumps over lazy dogs.",
+            label="Prompt",
+            placeholder="Enter a prompt for historical text generation...",
+            lines=3
         ),
+        gr.Slider(
+            label="Max New Tokens",
+            minimum=50,
+            maximum=1000,
+            step=50,
+            value=600
+        )
     ],
-    gr.HighlightedText(
-        label="Diff",
+    outputs=gr.HighlightedText(
+        label="Generated Historical Text",
         combine_adjacent=True,
-        show_legend=True,
-        color_map={"+": "red", "-": "green"}),
+        show_legend=True
+    ),
+    title="Historical Text Generation with OCRonos-Vintage",
+    description="Generate historical-style text using the OCRonos-Vintage model. The output shows token types as highlights.",
     theme=gr.themes.Base()
 )
+
 if __name__ == "__main__":
-    demo.launch()
+    iface.launch()
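
Aside: in the new historical_generation helper, the "token type" paired with each token is just the BPE token string round-tripped through its id, so the HighlightedText legend ends up labeling each span with its own tokenizer form. For anyone who wants to exercise the model from this commit without the Gradio UI, a minimal sketch follows; it reuses the checkpoint and sampling settings from app.py, and the prompt is an invented example:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Same checkpoint as app.py; sampling settings mirror the diff above.
model_name = "PleIAs/OCRonos-Vintage"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# app.py prepends this "### Text ###" header before generating.
prompt = "### Text ###\nOn a cold morning in 1833,"  # invented example prompt
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
output = model.generate(
    input_ids,
    max_new_tokens=100,
    pad_token_id=tokenizer.eos_token_id,
    top_k=50,
    temperature=0.3,
    top_p=0.95,
    do_sample=True,
    repetition_penalty=1.5,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))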