JustinLin610 committed
Commit 2915058
1 Parent(s): 7883098
Files changed (2)
  1. app.py +9 -9
  2. data/mm_data/ocr_dataset.py +10 -4
app.py CHANGED
@@ -70,7 +70,7 @@ def get_images(img: str, reader: ReaderLite, **kwargs):
     return results
 
 
-def draw_boxes(image, bounds, color='red', width=2):
+def draw_boxes(image, bounds, color='red', width=10):
     draw = ImageDraw.Draw(image)
     for i, bound in enumerate(bounds):
         p0, p1, p2, p3 = bound
@@ -102,7 +102,7 @@ def patch_resize_transform(patch_image_size=480, is_document=False):
     _patch_resize_transform = transforms.Compose(
         [
             lambda image: ocr_resize(
-                image, patch_image_size, is_document=is_document
+                image, patch_image_size, is_document=is_document, split='test',
             ),
             transforms.ToTensor(),
             transforms.Normalize(mean=mean, std=std),
@@ -113,7 +113,7 @@ def patch_resize_transform(patch_image_size=480, is_document=False):
 
 
 reader = ReaderLite()
-overrides={"eval_cider": False, "beam": 8, "max_len_b": 128, "patch_image_size": 480,
+overrides={"eval_cider": False, "beam": 4, "max_len_b": 32, "patch_image_size": 480,
            "orig_patch_image_size": 224, "no_repeat_ngram_size": 0, "seed": 7}
 models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
     utils.split_paths('checkpoints/ocr_general_clean.pt'),
@@ -163,9 +163,9 @@ def apply_half(t):
     return t
 
 
 def ocr(img):
     out_img = Image.open(img)
-    results = get_images(img, reader)
+    results = get_images(img, reader, link_threshold=0.2)
     box_list, image_list = zip(*results)
     draw_boxes(out_img, box_list)
 
@@ -191,9 +191,9 @@ description = "Gradio Demo for OFA-OCR. Upload your own image or click any one o
 article = "<p style='text-align: center'><a href='https://github.com/OFA-Sys/OFA' target='_blank'>OFA Github " \
           "Repo</a></p> "
 examples = [['lihe.png']]
-io = gr.Interface(fn=ocr, inputs=gr.inputs.Image(type='filepath'),
-                  outputs=[gr.outputs.Image(type='pil'), gr.outputs.Textbox(label="OCR result")],
+io = gr.Interface(fn=ocr, inputs=gr.inputs.Image(type='filepath', label='Image'),
+                  outputs=[gr.outputs.Image(type='pil', label='Image'), gr.outputs.Textbox(label="OCR result")],
                   title=title, description=description, article=article, examples=examples,
-                  allow_flagging=False, allow_screenshot=False)
+                  allow_flagging='never', allow_screenshot=False)
 io.launch(cache_examples=True)
 
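The first app.py hunk only touches the draw_boxes signature (width 2 to 10) and the diff cuts off at the corner unpacking, so the drawing call itself is not shown. A minimal sketch of what such a quadrilateral outline amounts to, assuming each bound is four (x, y) corner points as the unpacking suggests; the draw.line call is an assumption, not the repo's confirmed body:

from PIL import ImageDraw

def draw_boxes(image, bounds, color='red', width=10):
    # outline each detected text region on the PIL image in place
    draw = ImageDraw.Draw(image)
    for i, bound in enumerate(bounds):
        p0, p1, p2, p3 = bound  # four (x, y) corners of one region
        # close the quadrilateral back to p0; the wider 10 px stroke
        # keeps boxes visible once the demo image is scaled for display
        draw.line([*p0, *p1, *p2, *p3, *p0], fill=color, width=width)
    return image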
 
data/mm_data/ocr_dataset.py CHANGED
@@ -82,7 +82,7 @@ def collate(samples, pad_idx, eos_idx):
     return batch
 
 
-def ocr_resize(img, patch_image_size, is_document=False):
+def ocr_resize(img, patch_image_size, is_document=False, split='train'):
     img = img.convert("RGB")
     width, height = img.size
 
@@ -92,13 +92,19 @@ def ocr_resize(img, patch_image_size, is_document=False):
     if width >= height:
         new_width = max(64, patch_image_size)
         new_height = max(64, int(patch_image_size * (height / width)))
-        top = random.randint(0, patch_image_size - new_height)
+        if split != 'train':
+            top = int((patch_image_size - new_height) // 2)
+        else:
+            top = random.randint(0, patch_image_size - new_height)
         bottom = patch_image_size - new_height - top
         left, right = 0, 0
     else:
         new_height = max(64, patch_image_size)
         new_width = max(64, int(patch_image_size * (width / height)))
-        left = random.randint(0, patch_image_size - new_width)
+        if split != 'train':
+            left = int((patch_image_size - new_width) // 2)
+        else:
+            left = random.randint(0, patch_image_size - new_width)
         right = patch_image_size - new_width - left
         top, bottom = 0, 0
 
@@ -151,7 +157,7 @@ class OcrDataset(OFADataset):
         self.patch_resize_transform = transforms.Compose(
             [
                 lambda image: ocr_resize(
-                    image, patch_image_size, is_document=is_document
+                    image, patch_image_size, is_document=is_document, split=split,
                 ),
                 transforms.ToTensor(),
                 transforms.Normalize(mean=mean, std=std),
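Read together, the ocr_resize hunks pin the long side to patch_image_size, scale the short side proportionally, and pad the leftover margin: with a random offset at train time (a light positional augmentation) and a centered offset for any other split, which makes inference deterministic; app.py passes split='test' for exactly that reason. A minimal runnable sketch of the behavior, assuming the function goes on to resize to (new_width, new_height) and pad with the computed offsets (the resize/pad calls and the white fill fall outside the shown hunks and are assumptions):

import random
from PIL import Image
import torchvision.transforms.functional as F

def ocr_resize_sketch(img, patch_image_size=480, split='train'):
    img = img.convert("RGB")
    width, height = img.size
    if width >= height:
        # long side is width: pin it to patch_image_size, shrink height
        new_width = max(64, patch_image_size)
        new_height = max(64, int(patch_image_size * (height / width)))
        if split != 'train':
            top = (patch_image_size - new_height) // 2  # centered at eval
        else:
            top = random.randint(0, patch_image_size - new_height)
        bottom = patch_image_size - new_height - top
        left, right = 0, 0
    else:
        # long side is height: shrink width instead
        new_height = max(64, patch_image_size)
        new_width = max(64, int(patch_image_size * (width / height)))
        if split != 'train':
            left = (patch_image_size - new_width) // 2
        else:
            left = random.randint(0, patch_image_size - new_width)
        right = patch_image_size - new_width - left
        top, bottom = 0, 0
    img = F.resize(img, [new_height, new_width])
    # pad out to a square patch_image_size x patch_image_size canvas
    return F.pad(img, [left, top, right, bottom], fill=255)  # fill assumed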