Spaces:
Sleeping
Sleeping
Pclanglais
commited on
Commit
•
1fca231
1
Parent(s):
fa86caf
Update app.py
Browse files
app.py
CHANGED
@@ -11,6 +11,7 @@ import shutil
|
|
11 |
import requests
|
12 |
import pandas as pd
|
13 |
import difflib
|
|
|
14 |
|
15 |
# Define the device
|
16 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
@@ -168,24 +169,32 @@ def split_text(text, max_tokens=500):
|
|
168 |
|
169 |
|
170 |
# Function to generate text
|
171 |
-
def ocr_correction(prompt, max_new_tokens=600):
|
172 |
-
|
173 |
prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
|
174 |
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
|
175 |
|
|
|
|
|
|
|
176 |
# Generate text
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
|
182 |
# Decode and return the generated text
|
183 |
result = tokenizer.decode(output[0], skip_special_tokens=True)
|
184 |
-
|
185 |
print(result)
|
186 |
-
|
187 |
result = result.split("### Correction ###")[1]
|
188 |
-
|
189 |
return result
|
190 |
|
191 |
# OCR Correction Class
|
|
|
11 |
import requests
|
12 |
import pandas as pd
|
13 |
import difflib
|
14 |
+
from concurrent.futures import ThreadPoolExecutor
|
15 |
|
16 |
# Define the device
|
17 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
169 |
|
170 |
|
171 |
# Function to generate text
|
172 |
+
def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
|
|
|
173 |
prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
|
174 |
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
|
175 |
|
176 |
+
# Set the number of threads for PyTorch
|
177 |
+
torch.set_num_threads(num_threads)
|
178 |
+
|
179 |
# Generate text
|
180 |
+
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
181 |
+
future = executor.submit(
|
182 |
+
model.generate,
|
183 |
+
input_ids,
|
184 |
+
max_new_tokens=max_new_tokens,
|
185 |
+
pad_token_id=tokenizer.eos_token_id,
|
186 |
+
top_k=50,
|
187 |
+
num_return_sequences=1,
|
188 |
+
do_sample=True,
|
189 |
+
temperature=0.7
|
190 |
+
)
|
191 |
+
output = future.result()
|
192 |
|
193 |
# Decode and return the generated text
|
194 |
result = tokenizer.decode(output[0], skip_special_tokens=True)
|
|
|
195 |
print(result)
|
196 |
+
|
197 |
result = result.split("### Correction ###")[1]
|
|
|
198 |
return result
|
199 |
|
200 |
# OCR Correction Class
|