Spaces:
Sleeping
Sleeping
Pclanglais
commited on
Commit
•
dd838d3
1
Parent(s):
ffbf266
Update app.py
Browse files
app.py
CHANGED
@@ -13,22 +13,19 @@ import pandas as pd
|
|
13 |
import difflib
|
14 |
from concurrent.futures import ThreadPoolExecutor
|
15 |
|
16 |
-
# Define the device
|
17 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
18 |
-
|
19 |
# OCR Correction Model
|
20 |
ocr_model_name = "PleIAs/OCRonos-Vintage"
|
21 |
|
22 |
import torch
|
23 |
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
24 |
|
|
|
|
|
25 |
# Load pre-trained model and tokenizer
|
26 |
model_name = "PleIAs/OCRonos-Vintage"
|
27 |
model = GPT2LMHeadModel.from_pretrained(model_name)
|
28 |
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
|
29 |
|
30 |
-
# Set the device to GPU if available, otherwise use CPU
|
31 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
32 |
model.to(device)
|
33 |
|
34 |
# CSS for formatting
|
@@ -169,7 +166,9 @@ def split_text(text, max_tokens=500):
|
|
169 |
|
170 |
|
171 |
# Function to generate text
|
172 |
-
|
|
|
|
|
173 |
prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
|
174 |
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
|
175 |
|
@@ -177,9 +176,7 @@ def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
|
|
177 |
torch.set_num_threads(num_threads)
|
178 |
|
179 |
# Generate text
|
180 |
-
|
181 |
-
future = executor.submit(
|
182 |
-
model.generate,
|
183 |
input_ids,
|
184 |
max_new_tokens=max_new_tokens,
|
185 |
pad_token_id=tokenizer.eos_token_id,
|
@@ -188,8 +185,6 @@ def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
|
|
188 |
do_sample=True,
|
189 |
temperature=0.7
|
190 |
)
|
191 |
-
output = future.result()
|
192 |
-
|
193 |
# Decode and return the generated text
|
194 |
result = tokenizer.decode(output[0], skip_special_tokens=True)
|
195 |
print(result)
|
|
|
13 |
import difflib
|
14 |
from concurrent.futures import ThreadPoolExecutor
|
15 |
|
|
|
|
|
|
|
16 |
# OCR Correction Model
|
17 |
ocr_model_name = "PleIAs/OCRonos-Vintage"
|
18 |
|
19 |
import torch
|
20 |
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
21 |
|
22 |
+
device = "cuda"
|
23 |
+
|
24 |
# Load pre-trained model and tokenizer
|
25 |
model_name = "PleIAs/OCRonos-Vintage"
|
26 |
model = GPT2LMHeadModel.from_pretrained(model_name)
|
27 |
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
|
28 |
|
|
|
|
|
29 |
model.to(device)
|
30 |
|
31 |
# CSS for formatting
|
|
|
166 |
|
167 |
|
168 |
# Function to generate text
|
169 |
+
@spaces.GPU
|
170 |
+
def ocr_correction(prompt, max_new_tokens=500):
|
171 |
+
|
172 |
prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
|
173 |
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
|
174 |
|
|
|
176 |
torch.set_num_threads(num_threads)
|
177 |
|
178 |
# Generate text
|
179 |
+
output = model.generate,
|
|
|
|
|
180 |
input_ids,
|
181 |
max_new_tokens=max_new_tokens,
|
182 |
pad_token_id=tokenizer.eos_token_id,
|
|
|
185 |
do_sample=True,
|
186 |
temperature=0.7
|
187 |
)
|
|
|
|
|
188 |
# Decode and return the generated text
|
189 |
result = tokenizer.decode(output[0], skip_special_tokens=True)
|
190 |
print(result)
|