JustinLin610 committed on
Commit
6cf5e8c
•
1 Parent(s): 3006ddf
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -136,8 +136,8 @@ def patch_resize_transform(patch_image_size=480, is_document=False):
136
 
137
 
138
  # Construct input for caption task
139
- def construct_sample(image: Image, patch_image_size=480):
140
- patch_image = patch_resize_transform(patch_image_size)(image).unsqueeze(0)
141
  patch_mask = torch.tensor([True])
142
  src_text = encode_text("图片上的文字是什么?", append_bos=True, append_eos=True).unsqueeze(0)
143
  src_length = torch.LongTensor([s.ne(pad_idx).long().sum() for s in src_text])
@@ -177,7 +177,7 @@ def ocr(img, task_type):
177
  ocr_result = []
178
  for i, (box, image) in enumerate(zip(box_list, image_list)):
179
  image = Image.fromarray(image)
180
- sample = construct_sample(task, image, cfg.task.patch_image_size)
181
  sample = utils.move_to_cuda(sample) if use_cuda else sample
182
  sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample
183
 
 
136
 
137
 
138
  # Construct input for caption task
139
+ def construct_sample(image: Image, patch_image_size=480, is_document=False):
140
+ patch_image = patch_resize_transform(patch_image_size, is_document=is_document)(image).unsqueeze(0)
141
  patch_mask = torch.tensor([True])
142
  src_text = encode_text("图片上的文字是什么?", append_bos=True, append_eos=True).unsqueeze(0)
143
  src_length = torch.LongTensor([s.ne(pad_idx).long().sum() for s in src_text])
 
177
  ocr_result = []
178
  for i, (box, image) in enumerate(zip(box_list, image_list)):
179
  image = Image.fromarray(image)
180
+ sample = construct_sample(image, cfg.task.patch_image_size, is_document=(task_type=='Document'))
181
  sample = utils.move_to_cuda(sample) if use_cuda else sample
182
  sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample
183