Pclanglais committed on
Commit
f16bf84
1 Parent(s): ba05a34

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -7
app.py CHANGED
@@ -4,6 +4,7 @@ from transformers import GPT2LMHeadModel, GPT2Tokenizer
4
  import torch
5
  import gradio as gr
6
  import difflib
 
7
  from concurrent.futures import ThreadPoolExecutor
8
  import os
9
 
@@ -56,23 +57,18 @@ def split_text(text, max_tokens=400):
56
 
57
  return chunks
58
 
 
59
  def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
60
  prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
61
  input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
62
 
63
- torch.set_num_threads(num_threads)
64
-
65
- with ThreadPoolExecutor(max_workers=num_threads) as executor:
66
- future = executor.submit(
67
- model.generate,
68
- input_ids,
69
  max_new_tokens=max_new_tokens,
70
  pad_token_id=tokenizer.eos_token_id,
71
  top_k=50,
72
  num_return_sequences=1,
73
  do_sample=False
74
  )
75
- output = future.result()
76
 
77
  result = tokenizer.decode(output[0], skip_special_tokens=True)
78
  return result.split("### Correction ###")[1].strip()
 
4
  import torch
5
  import gradio as gr
6
  import difflib
7
+ import spaces
8
  from concurrent.futures import ThreadPoolExecutor
9
  import os
10
 
 
57
 
58
  return chunks
59
 
60
+ @spaces.GPU
61
  def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
62
  prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
63
  input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
64
 
65
+ output = model.generate(input_ids,
 
 
 
 
 
66
  max_new_tokens=max_new_tokens,
67
  pad_token_id=tokenizer.eos_token_id,
68
  top_k=50,
69
  num_return_sequences=1,
70
  do_sample=False
71
  )
 
72
 
73
  result = tokenizer.decode(output[0], skip_special_tokens=True)
74
  return result.split("### Correction ###")[1].strip()