Karthika0308 commited on
Commit
c81bba7
·
verified ·
1 Parent(s): 1b77aaf

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +170 -0
  2. inference.py +198 -0
app.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
3
+ import torch
4
+ import gradio as gr
5
+ from gradio_image_prompter import ImagePrompter
6
+ from torch.nn import DataParallel
7
+ from models.counter_infer import build_model
8
+ from utils.arg_parser import get_argparser
9
+ from utils.data import resize_and_pad
10
+ import torchvision.ops as ops
11
+ from torchvision import transforms as T
12
+ from PIL import Image, ImageDraw, ImageFont
13
+ import numpy as np
14
+
15
# Build the counting model a single time at import so every request reuses it.
def load_model():
    """Construct the zero-shot counter, load its checkpoint, and return (model, device)."""
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    args = get_argparser().parse_args()
    args.zero_shot = True
    net = build_model(args).to(device)
    model = DataParallel(net)
    checkpoint = torch.load('CNTQG_multitrain_ca44.pth', weights_only=True)
    # strict=False: tolerate checkpoint/model key mismatches.
    model.load_state_dict(checkpoint['model'], strict=False)
    model.eval()
    return model, device


model, device = load_model()
26
+
27
# **Function to Process Image Once**
def process_image_once(inputs, enable_mask):
    """Run one forward pass of the counter on the prompted image.

    Args:
        inputs: dict from ImagePrompter with 'image' (H x W x 3 uint8 array)
            and 'points' (drawn prompts; a box is [x1, y1, kind, x2, y2, kind2]).
        enable_mask: whether the model should also predict instance masks.

    Returns:
        (image, outputs, masks, img, scale, drawn_boxes) — the raw inputs plus
        everything post_process() needs, so re-thresholding avoids re-inference.
    """
    model.module.return_masks = enable_mask

    image = inputs['image']
    drawn_boxes = inputs['points']
    image_tensor = torch.tensor(image).to(device)
    image_tensor = image_tensor.permute(2, 0, 1).float() / 255.0
    # ImageNet normalization — presumably matches training preprocessing.
    image_tensor = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image_tensor)

    # ImagePrompter encodes a box as [x1, y1, type, x2, y2, type2]; keep coords only.
    # BUGFIX: .reshape(-1, 4) keeps the tensor (0, 4)-shaped when no boxes were
    # drawn; the bare torch.tensor([...]) would yield shape (0,) and break
    # resize_and_pad's box arithmetic.
    bboxes_tensor = torch.tensor(
        [[box[0], box[1], box[3], box[4]] for box in drawn_boxes],
        dtype=torch.float32,
    ).reshape(-1, 4).to(device)

    img, bboxes, scale = resize_and_pad(image_tensor, bboxes_tensor, size=1024.0)
    img = img.unsqueeze(0).to(device)
    bboxes = bboxes.unsqueeze(0).to(device)

    with torch.no_grad():
        outputs, _, _, _, masks = model(img, bboxes)

    return image, outputs, masks, img, scale, drawn_boxes
47
+
48
# **Post-process and Update Output**
def post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold):
    """Threshold + NMS the cached detections and render boxes (and masks).

    Runs on the outputs cached by process_image_once(), so moving the
    threshold slider re-renders without re-running inference.

    Returns:
        (PIL image with drawn detections, number of kept detections).
    """
    idx = 0  # single-image batch
    # Slider value is a fraction in (0, 1); a detection is kept when its score
    # exceeds max_score / (1/threshold) == max_score * threshold.
    threshold = 1/threshold
    keep = ops.nms(outputs[idx]['pred_boxes'][outputs[idx]['box_v'] > outputs[idx]['box_v'].max() / threshold],
                   outputs[idx]['box_v'][outputs[idx]['box_v'] > outputs[idx]['box_v'].max() / threshold], 0.5)

    pred_boxes = outputs[idx]['pred_boxes'][outputs[idx]['box_v'] > outputs[idx]['box_v'].max() / threshold][keep]
    pred_boxes = torch.clamp(pred_boxes, 0, 1)  # box coords are normalized to [0, 1]

    # Map normalized coords back to original-image pixels (undo resize_and_pad).
    pred_boxes = (pred_boxes.cpu() / scale * img.shape[-1]).tolist()

    image = Image.fromarray((image).astype(np.uint8))

    if enable_mask:
        from matplotlib import pyplot as plt
        # Masks of the detections surviving the score threshold.
        masks_ = masks[idx][(outputs[idx]['box_v'] > outputs[idx]['box_v'].max() / threshold)[0]]
        N_masks = masks_.shape[0]
        # Random label in [1, N_masks] per instance so adjacent masks tend to
        # get distinct colormap colors (collisions possible, cosmetic only).
        indices = torch.randint(1, N_masks + 1, (1, N_masks), device=masks_.device).view(-1, 1, 1)
        masks = (masks_ * indices).sum(dim=0)
        # Resize the label map back to original resolution, then crop away the
        # padded border added by resize_and_pad.
        mask_display = (
            T.Resize((int(img.shape[2] / scale), int(img.shape[3] / scale)), interpolation=T.InterpolationMode.NEAREST)(
                masks.cpu().unsqueeze(0))[0])[:image.size[1], :image.size[0]]
        cmap = plt.cm.tab20
        norm = plt.Normalize(vmin=0, vmax=N_masks)
        # Release the large tensors before building the RGBA overlay.
        del masks
        del masks_
        del outputs
        rgba_image = cmap(norm(mask_display))
        # Background fully transparent; instance pixels at 50% opacity.
        rgba_image[mask_display == 0, -1] = 0
        rgba_image[mask_display != 0, -1] = 0.5

        overlay = Image.fromarray((rgba_image * 255).astype(np.uint8), mode="RGBA")
        image = image.convert("RGBA")
        image = Image.alpha_composite(image, overlay)


    draw = ImageDraw.Draw(image)
    for box in pred_boxes:
        draw.rectangle([box[0], box[1], box[2], box[3]], outline="orange", width=5)
    # for box in drawn_boxes:
    #     draw.rectangle([box[0], box[1], box[3], box[4]], outline="red", width=3)

    # Geometry for an (optional, currently disabled) count badge bottom-left.
    width, height = image.size
    square_size = int(0.05 * width)
    x1, y1 = 10, height - square_size - 10
    x2, y2 = x1 + square_size, y1 + square_size

    # draw.rectangle([x1, y1, x2, y2], outline="black", fill="black", width=1)
    # font = ImageFont.load_default()
    # txt = str(len(pred_boxes))
    # w = draw.textlength(txt, font=font)
    # text_x = x1 + (square_size - w) / 2
    # text_y = y1 + (square_size - 10) / 2
    # draw.text((text_x, text_y), txt, fill="white", font=font)

    return image, len(pred_boxes)
105
+
106
+
107
# Gradio UI: one inference pass per "Count" click; slider/checkbox changes
# only re-run post-processing on the cached model outputs.
iface = gr.Blocks()

with iface:
    # Store intermediate states (cached between the Count click and later
    # threshold/mask tweaks so inference is not repeated).
    image_input = gr.State()
    outputs_state = gr.State()
    masks_state = gr.State()
    img_state = gr.State()
    scale_state = gr.State()
    drawn_boxes_state = gr.State()

    # UI Layout: Input Section
    with gr.Row():
        image_prompter = ImagePrompter()
        image_output = gr.Image(type="pil")


    # UI Layout: Output Section
    with gr.Row():
        count_output = gr.Number(label="Total Count")
        enable_mask = gr.Checkbox(label="Predict masks", value=True)  # Mask enabled by default
        threshold = gr.Slider(0.05, 0.95, value=0.33, step=0.01, label="Threshold")  # Updated range and default


    # Create the 'Count' button
    count_button = gr.Button("Count")

    # Process image once when "Count" button is pressed
    def initial_process(inputs, enable_mask, threshold):
        # Perform inference once
        image, outputs, masks, img, scale, drawn_boxes = process_image_once(inputs, enable_mask)

        # Save intermediate states
        return (
            *post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold),  # Processed outputs
            image, outputs, masks, img, scale, drawn_boxes  # Store in states for later use
        )

    # Update image and count when the threshold slider changes (post-process only)
    def update_threshold(threshold, image, outputs, masks, img, scale, drawn_boxes, enable_mask):
        return post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold)

    # Run initial inference and post-process when "Count" button is clicked
    count_button.click(
        initial_process,
        [image_prompter, enable_mask, threshold],  # Inputs
        [image_output, count_output, image_input, outputs_state, masks_state, img_state, scale_state, drawn_boxes_state]  # Outputs + States
    )

    # Adjust the output dynamically based on the threshold slider (no re-inference)
    threshold.change(
        update_threshold,
        [threshold, image_input, outputs_state, masks_state, img_state, scale_state, drawn_boxes_state, enable_mask],
        [image_output, count_output]
    )

    # Toggling the mask checkbox re-renders from the same cached outputs.
    enable_mask.change(
        update_threshold,
        [threshold, image_input, outputs_state, masks_state, img_state, scale_state, drawn_boxes_state, enable_mask],
        [image_output, count_output]
    )

iface.launch(share=True)
170
+
inference.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import math
4
+ import os
5
+
6
+ import numpy as np
7
+ import skimage
8
+ import torch
9
+ from torch.nn import DataParallel
10
+ from torch.utils.data import DataLoader
11
+ from torchvision import ops
12
+ from torchvision.transforms import Resize
13
+ from tqdm import tqdm
14
+ from models.counter_infer import build_model
15
+ from models.matcher import build_matcher
16
+ from utils.arg_parser import get_argparser
17
+ from utils.box_ops import BoxList
18
+ from utils.data import FSC147DATASET, pad_collate_test
19
+ from utils.losses import SetCriterion
20
+
21
+
22
@torch.no_grad()
def evaluate(args):
    """Evaluate the counting model on the FSC147 'val' and 'test' splits.

    Per image, raw detections are score-thresholded, NMS-filtered, and pruned
    of boxes whose centres fall in the padded border. Per split, counting
    MAE/RMSE against the ground-truth density maps is printed and the kept
    detections are written to a COCO-style JSON file ("geco2_<split>.json").

    Args:
        args: parsed CLI namespace from get_argparser() (model/data paths,
            batch size, etc.).
    """
    gpu = 0
    torch.cuda.set_device(gpu)
    device = torch.device(gpu)

    model = DataParallel(
        build_model(args).to(device),
        device_ids=[gpu],
        output_device=gpu
    )

    state_dict = torch.load(os.path.join(args.model_path, f'{args.model_name}.pth'))['model']
    # Normalize checkpoint keys: DataParallel expects the 'module.' prefix.
    state_dict = {k if 'module.' in k else 'module.' + k: v for k, v in state_dict.items()}
    model.load_state_dict(state_dict, strict=False)

    for split in ['val', 'test']:
        test = FSC147DATASET(
            args.data_path,
            args.image_size,
            split=split,
            num_objects=args.num_objects,
            tiling_p=args.tiling_p,
            return_ids=True,
            training=False
        )
        test_loader = DataLoader(
            test,
            batch_size=args.batch_size,
            drop_last=False,
            num_workers=args.num_workers,
            collate_fn=pad_collate_test,
        )
        ae = torch.tensor(0.0).to(device)  # running absolute count error
        se = torch.tensor(0.0).to(device)  # running squared count error
        model.eval()
        # NOTE: the original also built an (unused) matcher/SetCriterion pair
        # and several unused accumulators; removed as dead code.

        # COCO-style container for this split's detections.
        predictions = dict()
        predictions["categories"] = [{"name": "fg", "id": 1}]
        predictions["images"] = list()
        predictions["annotations"] = list()
        anno_id = 1

        for img, bboxes, density_map, ids, gt_bboxes, scaling_factor, padwh in test_loader:
            img = img.to(device)
            bboxes = bboxes.to(device)
            gt_bboxes = gt_bboxes.to(device)

            outputs, ref_points, _, _, masks = model(img, bboxes)

            w, h = img.shape[-1], img.shape[-2]
            num_objects_pred = []
            nms_bboxes = []
            nms_scores = []
            for idx in range(img.shape[0]):

                # Keep detections scoring above max_score * 0.11.
                thr = 1/0.11

                if len(outputs[idx]['pred_boxes'][-1]) == 0:
                    # No raw detections at all for this image.
                    nms_bboxes.append(torch.zeros((0, 4)))
                    nms_scores.append(torch.zeros((0)))
                    num_objects_pred.append(0)

                else:

                    # threshold and NMS
                    v = outputs[idx]["box_v"]
                    v_thr = v.max() / thr
                    mask = v > v_thr
                    keep = ops.nms(
                        outputs[idx]["pred_boxes"][mask],
                        v[mask],
                        0.5,
                    )
                    boxes = outputs[idx]["pred_boxes"][mask][keep]
                    boxes = torch.clamp(boxes, 0, 1)  # normalized coords
                    scores = outputs[idx]["scores"][mask][keep]

                    # remove bboxes whose centre lies in the padded area
                    maxw = (img.shape[-1] - padwh[idx][0]).to(device)
                    maxh = (img.shape[-2] - padwh[idx][1]).to(device)
                    center = (boxes[:, :2] + boxes[:, 2:]) / 2
                    # BUGFIX: x scales by image width, y by height (the
                    # original multiplied x by h and y by w — equivalent
                    # only for square inputs).
                    valid = (center[:, 0] * w < maxw) & (center[:, 1] * h < maxh)
                    scores = scores[valid]
                    boxes = boxes[valid]

                    nms_bboxes.append(boxes)
                    nms_scores.append(scores)
                    num_objects_pred.append(len(boxes))

            for idx in range(img.shape[0]):
                img_info = {
                    "id": test.map_img_name_to_ori_id()[test.image_names[ids[idx].item()]],
                    "file_name": "None",
                }
                # Convert to COCO xywh and undo the resize back to original
                # pixels. Named boxes_xywh so the dataloader's 'bboxes' (the
                # exemplar prompts) is no longer shadowed.
                boxes_xywh = ops.box_convert(nms_bboxes[idx], 'xyxy', 'xywh')
                boxes_xywh = boxes_xywh * img.shape[-1] / scaling_factor[idx]
                for idxi in range(len(nms_bboxes[idx])):
                    box = boxes_xywh[idxi].detach().cpu()
                    anno = {
                        "id": anno_id,
                        "image_id": test.map_img_name_to_ori_id()[test.image_names[ids[idx].item()]],
                        "area": int((box[2] * box[3]).item()),
                        "bbox": [int(box[0].item()), int(box[1].item()), int(box[2].item()), int(box[3].item())],
                        "category_id": 1,
                        "score": float(nms_scores[idx][idxi].item()),
                    }
                    anno_id += 1
                    predictions["annotations"].append(anno)
                predictions["images"].append(img_info)
            # Ground-truth count = integral of the density map per image.
            num_objects_gt = density_map.flatten(1).sum(dim=1)
            num_objects_pred = torch.tensor(num_objects_pred)
            ae += torch.abs(
                num_objects_gt - num_objects_pred
            ).sum()
            se += torch.pow(
                num_objects_gt - num_objects_pred, 2
            ).sum()
        print(
            f"{split.capitalize()} set",
            f"MAE: {ae.item() / len(test):.2f}",
            f"RMSE: {torch.sqrt(se / len(test)).item():.2f}",
        )

        with open("geco2_" + split + ".json", "w") as handle:
            json.dump(predictions, handle)
191
+
192
+
193
if __name__ == '__main__':
    # Extend the shared project options with this script's own parser.
    cli = argparse.ArgumentParser('GECO2', parents=[get_argparser()])
    parsed = cli.parse_args()
    # Echo the full configuration before evaluation for reproducibility.
    print(parsed)
    print("model_name: ", parsed.model_name)
    evaluate(parsed)