Spaces:
Running
Running
Pclanglais
committed on
Commit
•
f16bf84
1
Parent(s):
ba05a34
Update app.py
Browse files
app.py
CHANGED
@@ -4,6 +4,7 @@ from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
|
4 |
import torch
|
5 |
import gradio as gr
|
6 |
import difflib
|
|
|
7 |
from concurrent.futures import ThreadPoolExecutor
|
8 |
import os
|
9 |
|
@@ -56,23 +57,18 @@ def split_text(text, max_tokens=400):
|
|
56 |
|
57 |
return chunks
|
58 |
|
|
|
59 |
def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
|
60 |
prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
|
61 |
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
|
62 |
|
63 |
-
|
64 |
-
|
65 |
-
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
66 |
-
future = executor.submit(
|
67 |
-
model.generate,
|
68 |
-
input_ids,
|
69 |
max_new_tokens=max_new_tokens,
|
70 |
pad_token_id=tokenizer.eos_token_id,
|
71 |
top_k=50,
|
72 |
num_return_sequences=1,
|
73 |
do_sample=False
|
74 |
)
|
75 |
-
output = future.result()
|
76 |
|
77 |
result = tokenizer.decode(output[0], skip_special_tokens=True)
|
78 |
return result.split("### Correction ###")[1].strip()
|
|
|
4 |
import torch
|
5 |
import gradio as gr
|
6 |
import difflib
|
7 |
+
import spaces
|
8 |
from concurrent.futures import ThreadPoolExecutor
|
9 |
import os
|
10 |
|
|
|
57 |
|
58 |
return chunks
|
59 |
|
60 |
+
@spaces.GPU
def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
    """Return the model's corrected version of an OCR'd text passage.

    The text is wrapped in the ``### Text ### / ### Correction ###`` template,
    encoded with the module-level ``tokenizer``, and continued by the
    module-level ``model`` using greedy decoding.

    Args:
        prompt: Raw text to correct.
        max_new_tokens: Upper bound on the number of tokens to generate.
        num_threads: Unused by this implementation; retained so existing
            callers that pass it keep working.

    Returns:
        The generated correction, stripped of surrounding whitespace.
    """
    marker = "### Correction ###"
    prompt = f"""### Text ###\n{prompt}\n\n\n{marker}\n"""
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Greedy decoding (do_sample=False): sampling-only options such as top_k
    # have no effect in this mode, so they are not passed.
    output = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        num_return_sequences=1,
        do_sample=False,
    )

    # The returned sequence starts with the prompt tokens. Decode only what
    # follows them, so a marker string occurring inside the user's own text
    # cannot be mistaken for the start of the correction.
    generated_ids = output[0][input_ids.shape[-1]:]
    generated = tokenizer.decode(generated_ids, skip_special_tokens=True)

    # If the model starts emitting a new template section, keep only the text
    # before it (same cut-off the previous split-based extraction applied).
    return generated.split(marker)[0].strip()
|