phamvi856 committed on
Commit fd20fcc · 1 Parent(s): 9255af8

Update app.py

Files changed (1)
  1. app.py +5 -17
app.py CHANGED
@@ -57,15 +57,13 @@ def process_image(image):
     width, height = image.size
 
     # Encode image
-    encoding = processor(image, truncation=True, return_offsets_mapping=True, return_tensors="pt")
+    encoding = processor(image, truncation=True, padding="max_length", max_length=512, return_tensors="pt")
     input_ids = encoding.input_ids.to(device)
     attention_mask = encoding.attention_mask.to(device)
     bbox = encoding.bbox.to(device)
 
-    # Predict token labels
-    with torch.no_grad():
-        outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
-
+    # Inference
+    outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
     predicted_labels = outputs.logits.argmax(dim=2).squeeze().tolist()
 
     # Extract content from boxes
@@ -74,12 +72,7 @@ def process_image(image):
         predicted_label = id2label[predicted_labels[idx]]
         box_width = np.array(box)[2] - np.array(box)[0]
         box_height = np.array(box)[3] - np.array(box)[1]
-        normalized_box = [
-            box[0] * width / 1000,
-            box[1] * height / 1000,
-            box_width * width / 1000,
-            box_height * height / 1000,
-        ]
+        normalized_box = unnormalize_box(box, width, height)
         extracted_content[predicted_label] = image.crop(normalized_box).copy()
 
     # Draw predictions over the image
@@ -89,12 +82,7 @@ def process_image(image):
         predicted_label = iob_to_label(id2label[prediction])
         box_width = np.array(box)[2] - np.array(box)[0]
         box_height = np.array(box)[3] - np.array(box)[1]
-        normalized_box = [
-            box[0] * width / 1000,
-            box[1] * height / 1000,
-            box_width * width / 1000,
-            box_height * height / 1000,
-        ]
+        normalized_box = unnormalize_box(box, width, height)
         draw.rectangle(normalized_box, outline=label2color[predicted_label])
         draw.text((normalized_box[0] + 10, normalized_box[1] - 10), text=predicted_label, fill=label2color[predicted_label], font=font)
 
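
Both call sites now rely on an unnormalize_box helper whose definition is not part of this diff. Judging from the inline arithmetic it replaces, it is most likely the usual LayoutLM-style utility that maps the processor's 0-1000 normalized bounding boxes back to pixel coordinates; a minimal sketch under that assumption:

# Hypothetical sketch of the helper assumed by this commit; its actual
# definition lives elsewhere in app.py and is not shown in the diff.
def unnormalize_box(box, width, height):
    # The processor emits boxes on a 0-1000 grid; scale each coordinate
    # back to the image's pixel dimensions.
    return [
        width * (box[0] / 1000),
        height * (box[1] / 1000),
        width * (box[2] / 1000),
        height * (box[3] / 1000),
    ]

If the helper does follow this convention, it returns the right and bottom edges rather than the width and height values the removed inline lists computed, which is the (left, top, right, bottom) box format that image.crop() and draw.rectangle() expect.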