JustinLin610 committed
Commit: 3006ddf
Parent(s): 332c912

Files changed (1):
  1. app.py (+28 -32)
app.py CHANGED
@@ -41,6 +41,31 @@ Rect = Tuple[int, int, int, int]
 FourPoint = Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int], Tuple[int, int]]
 
 
+reader = ReaderLite(gpu=True)
+overrides={"eval_cider": False, "beam": 5, "max_len_b": 64, "patch_image_size": 480,
+           "orig_patch_image_size": 224, "no_repeat_ngram_size": 0, "seed": 42}
+models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
+    utils.split_paths('checkpoints/ocr_general_clean.pt'),
+    arg_overrides=overrides
+)
+
+# Move models to GPU
+for model in models:
+    model.eval()
+    if use_fp16:
+        model.half()
+    if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
+        model.cuda()
+    model.prepare_for_inference_(cfg)
+
+# Initialize generator
+generator = task.build_generator(models, cfg.generation)
+
+bos_item = torch.LongTensor([task.src_dict.bos()])
+eos_item = torch.LongTensor([task.src_dict.eos()])
+pad_idx = task.src_dict.pad()
+
+
 def four_point_transform(image: np.ndarray, rect: FourPoint) -> np.ndarray:
     (tl, tr, br, bl) = rect
 
@@ -81,10 +106,7 @@ def draw_boxes(image, bounds, color='red', width=4):
     return image
 
 
-def encode_text(task, text, length=None, append_bos=False, append_eos=False):
-    bos_item = torch.LongTensor([task.src_dict.bos()])
-    eos_item = torch.LongTensor([task.src_dict.eos()])
-
+def encode_text(text, length=None, append_bos=False, append_eos=False):
     s = task.tgt_dict.encode_line(
         line=task.bpe.encode(text),
         add_if_not_exist=False,
@@ -113,37 +135,11 @@ def patch_resize_transform(patch_image_size=480, is_document=False):
     return _patch_resize_transform
 
 
-reader = ReaderLite(gpu=True)
-
-overrides={"eval_cider": False, "beam": 5, "max_len_b": 64, "patch_image_size": 480,
-           "orig_patch_image_size": 224, "no_repeat_ngram_size": 0, "seed": 42}
-models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
-    utils.split_paths('checkpoints/ocr_general_clean.pt'),
-    arg_overrides=overrides
-)
-
-# Move models to GPU
-for model in models:
-    model.eval()
-    if use_fp16:
-        model.half()
-    if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
-        model.cuda()
-    model.prepare_for_inference_(cfg)
-
-# Initialize generator
-generator = task.build_generator(models, cfg.generation)
-
-bos_item = torch.LongTensor([task.src_dict.bos()])
-eos_item = torch.LongTensor([task.src_dict.eos()])
-pad_idx = task.src_dict.pad()
-
-
 # Construct input for caption task
-def construct_sample(task, image: Image, patch_image_size=480):
+def construct_sample(image: Image, patch_image_size=480):
     patch_image = patch_resize_transform(patch_image_size)(image).unsqueeze(0)
     patch_mask = torch.tensor([True])
-    src_text = encode_text(task, "图片上的文字是什么?", append_bos=True, append_eos=True).unsqueeze(0)
+    src_text = encode_text("图片上的文字是什么?", append_bos=True, append_eos=True).unsqueeze(0)
     src_length = torch.LongTensor([s.ne(pad_idx).long().sum() for s in src_text])
     sample = {
         "id":np.array(['42']),