itzjunayed committed on
Commit
5afef5d
·
verified ·
1 Parent(s): eaa26d4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +196 -0
app.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ from pathlib import Path
4
+ import pandas as pd
5
+ import gradio as gr
6
+ import torch
7
+ import numpy as np
8
+ from PIL import Image as PILImage
9
+ from detectron2 import model_zoo
10
+ from detectron2.config import get_cfg
11
+ from detectron2.engine import DefaultPredictor
12
+ from detectron2.data import DatasetCatalog, MetadataCatalog
13
+ from detectron2.utils.visualizer import ColorMode, Visualizer
14
+ from tqdm import tqdm
15
+ import uuid
16
+ import cv2
17
+ import pickle
18
+ from math import ceil
19
+ from typing import List, Dict, Any
20
+ from numpy import ndarray
21
+
22
def get_vinbigdata_dicts_test(imgdir: Path, test_meta: pd.DataFrame, use_cache: bool = True, debug: bool = False):
    """Build, or load from the pickle cache, the dataset records for inference images.

    Each record is a dict with ``file_name``, ``image_id``, ``height`` and
    ``width``. Height and width are read from the first image only and reused
    for every record, so all images are assumed to share one size.

    Args:
        imgdir: Directory holding ``<image_id>.png`` files. A plain string is
            accepted and converted to ``Path``.
        test_meta: Frame with an ``image_id`` column, one row per image.
        use_cache: When true and the cache file exists, its content is
            returned without looking at ``test_meta`` or the images.
        debug: Restrict to the first 500 rows and use a separate cache file.

    Returns:
        The list of record dicts (empty when ``test_meta`` has no rows).

    Raises:
        FileNotFoundError: If the first image cannot be read.
    """
    # Callers in this file pass a plain string; ``str / str`` raises TypeError.
    imgdir = Path(imgdir)
    debug_str = f"_debug{int(debug)}"
    # NOTE(review): the cache name does not depend on ``imgdir``, so a cache
    # written for one directory is returned for any other while
    # ``use_cache`` is true. Pass ``use_cache=False`` for per-request data.
    cache_path = Path(".") / f"dataset_dicts_cache_test{debug_str}.pkl"
    if not use_cache or not cache_path.exists():
        if debug:
            test_meta = test_meta.iloc[:500]  # For debug

        dataset_dicts = []
        if len(test_meta) > 0:
            # Positional access: the frame's index labels need not start at 0.
            first_image_id = test_meta["image_id"].iloc[0]
            first_image_path = imgdir / f"{first_image_id}.png"
            image = cv2.imread(str(first_image_path))
            if image is None:
                # cv2.imread signals failure by returning None, not raising.
                raise FileNotFoundError(f"Could not read image: {first_image_path}")
            resized_height, resized_width = image.shape[:2]

            for image_id in tqdm(test_meta["image_id"], total=len(test_meta)):
                dataset_dicts.append(
                    {
                        "file_name": str(imgdir / f"{image_id}.png"),
                        "image_id": image_id,
                        "height": resized_height,
                        "width": resized_width,
                    }
                )

        with open(cache_path, mode="wb") as f:
            pickle.dump(dataset_dicts, f)

    # The pickle is a local cache written by this function; never point this
    # at a file that comes from an untrusted source.
    with open(cache_path, mode="rb") as f:
        dataset_dicts = pickle.load(f)
    return dataset_dicts
52
+
53
def format_pred(labels: ndarray, boxes: ndarray, scores: ndarray) -> str:
    """Serialise detections into one space-separated prediction string.

    Every detection contributes ``label score xmin ymin xmax ymax``; the box
    coordinates are cast to ``int64`` before formatting.

    Args:
        labels: Class label per detection.
        boxes: One 4-element box per detection.
        scores: Score per detection.

    Returns:
        All detections joined by single spaces; empty string for no detections.
    """
    template = "{} {} {} {} {} {}"
    chunks = []
    for cls_id, conf, box in zip(labels, scores, boxes):
        left, top, right, bottom = box.astype(np.int64)
        chunks.append(template.format(cls_id, conf, left, top, right, bottom))
    return " ".join(chunks)
59
+
60
def predict_batch(predictor: DefaultPredictor, im_list: List[ndarray]) -> List:
    """Run the predictor's model on a list of images in a single call.

    Each image is converted to a float32 channel-first tensor and wrapped in a
    dict with its ``height`` and ``width``. The channel axis is reversed first
    when the predictor's ``input_format`` is ``"RGB"``.

    Args:
        predictor: Object exposing ``input_format`` and a callable ``model``.
        im_list: Images as height x width x channel arrays (presumably BGR, as
            produced by ``cv2.imread`` -- confirm against the caller).

    Returns:
        Whatever ``predictor.model`` returns for the batch.
    """

    def as_model_input(frame):
        # Reverse the channel order when the model was configured for RGB.
        if predictor.input_format == "RGB":
            frame = frame[:, :, ::-1]
        rows, cols = frame.shape[:2]
        chw = frame.astype("float32").transpose(2, 0, 1)
        return {"image": torch.as_tensor(chw), "height": rows, "width": cols}

    with torch.no_grad():
        batch = [as_model_input(frame) for frame in im_list]
        return predictor.model(batch)
72
+
73
def csv_create(new_image_path, image_id):
    """Write single-row ``sample_submission.csv`` and ``test_meta.csv`` beside the image.

    ``test_meta.csv`` stores the image size as ``dim0`` = height (rows) and
    ``dim1`` = width (columns). ``prediction`` in this file divides ``dim0`` by
    ``im.shape[0]`` and ``dim1`` by ``im.shape[1]``, so the columns must follow
    that row/column order.

    Args:
        new_image_path: Path of the image whose size is recorded. Both CSV
            files are written into its directory.
        image_id: Identifier stored in the ``image_id`` column of both files.

    Returns:
        Tuple ``(sample_submission_path, test_meta_path)``.
    """
    # The context manager closes the underlying file handle promptly.
    with PILImage.open(new_image_path) as image:
        # PIL reports size as (width, height).
        width, height = image.size
    directory = os.path.dirname(new_image_path)

    sample_submission_data = {
        'image_id': [image_id],
        'PredictionString': ['14 1 0 0 1 1'],
    }
    sample_submission_df = pd.DataFrame(sample_submission_data)
    sample_submission_path = os.path.join(directory, 'sample_submission.csv')
    sample_submission_df.to_csv(sample_submission_path, index=False)

    test_meta_data = {
        'image_id': [image_id],
        'dim0': [height],
        'dim1': [width],
    }
    test_meta_df = pd.DataFrame(test_meta_data)
    test_meta_path = os.path.join(directory, 'test_meta.csv')
    test_meta_df.to_csv(test_meta_path, index=False)

    return sample_submission_path, test_meta_path
96
+
97
def prediction(image_id_main, local_image_path, model_path):
    """Run the detector on one image and write the result as ``submission_det.csv``.

    The image is copied into ``processed_images_<image_id_main>``, a one-row
    meta CSV is written for it, a Faster R-CNN predictor is built from
    ``model_path``, and the detections are written to
    ``result_images/submission_det.csv``.

    Args:
        image_id_main: Identifier of the image; also used in the working
            directory name and as the ``image_id`` in the output.
        local_image_path: Path of the image file to process. The file name is
            expected to be ``<image_id_main>.png``, because the dataset
            records are looked up by that name.
        model_path: Path of the trained weights loaded into the model.

    Returns:
        Path of the written ``submission_det.csv``.

    Raises:
        FileNotFoundError: If an image listed in the dataset cannot be read.
    """
    thing_classes = [
        "Aortic enlargement", "Atelectasis", "Calcification", "Cardiomegaly", "Consolidation",
        "ILD", "Infiltration", "Lung Opacity", "Nodule/Mass", "Other lesion", "Pleural effusion",
        "Pleural thickening", "Pneumothorax", "Pulmonary fibrosis"
    ]

    outdir = 'result_images'
    os.makedirs(outdir, exist_ok=True)

    # A Path (not a str) is required: the dataset helper joins with ``/``.
    imgdir = Path(f'processed_images_{image_id_main}')
    os.makedirs(imgdir, exist_ok=True)
    shutil.copy(local_image_path, imgdir)
    new_image_path = os.path.join(imgdir, os.path.basename(local_image_path))

    _, test_meta_path = csv_create(new_image_path, image_id_main)
    # Read the meta file once. Forcing ``image_id`` to str keeps IDs such as
    # "007" intact; type inference would turn them into the integer 7.
    test_meta_df = pd.read_csv(test_meta_path, dtype={"image_id": str})

    cfg = get_cfg()
    cfg.OUTPUT_DIR = outdir
    cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
    cfg.DATASETS.TEST = ()
    cfg.DATALOADER.NUM_WORKERS = 2
    # Presumably training-time settings carried over from the training
    # script; kept unchanged.
    cfg.SOLVER.IMS_PER_BATCH = 2
    cfg.SOLVER.BASE_LR = 0.001
    cfg.SOLVER.MAX_ITER = 30000
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(thing_classes)

    cfg.MODEL.WEIGHTS = model_path
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.0
    # Fall back to CPU so the app also starts on hosts without a GPU.
    cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    predictor = DefaultPredictor(cfg)

    # ``use_cache=False``: the helper's pickle cache is shared by the whole
    # process, so reusing it would return the records of an earlier request.
    dataset_dicts = get_vinbigdata_dicts_test(
        imgdir, test_meta_df, use_cache=False, debug=False
    )

    unique_id = f"bigdata2_{uuid.uuid4().hex[:8]}"
    DatasetCatalog.register(unique_id, lambda: dataset_dicts)
    MetadataCatalog.get(unique_id).set(thing_classes=thing_classes)

    results_list = []
    batch_size = 4

    for start in tqdm(range(0, len(dataset_dicts), batch_size)):
        dataset_dicts_batch = dataset_dicts[start:start + batch_size]
        im_list = []
        for d in dataset_dicts_batch:
            im = cv2.imread(d["file_name"])
            if im is None:
                # cv2.imread signals failure by returning None, not raising.
                raise FileNotFoundError(f"Could not read image: {d['file_name']}")
            im_list.append(im)
        outputs_list = predict_batch(predictor, im_list)

        for offset, (im, outputs) in enumerate(zip(im_list, outputs_list)):
            # Index by the image's position in the dataset, not by the batch
            # number, so each image is paired with its own meta row.
            meta_row = test_meta_df.iloc[start + offset]
            image_id = meta_row["image_id"]
            dim0 = meta_row["dim0"]
            dim1 = meta_row["dim1"]

            instances = outputs["instances"]
            if len(instances) == 0:
                result = {"image_id": image_id, "PredictionString": "14 1.0 0 0 1 1"}
            else:
                fields: Dict[str, Any] = instances.get_fields()
                pred_classes = fields["pred_classes"]
                pred_scores = fields["scores"]
                pred_boxes = fields["pred_boxes"].tensor

                # Rescale boxes from the processed image size to the size
                # recorded in the meta file (dim0 = rows, dim1 = columns).
                h_ratio = float(dim0) / im.shape[0]
                w_ratio = float(dim1) / im.shape[1]
                pred_boxes[:, [0, 2]] *= w_ratio
                pred_boxes[:, [1, 3]] *= h_ratio

                result = {
                    "image_id": image_id,
                    "PredictionString": format_pred(
                        pred_classes.cpu().numpy(),
                        pred_boxes.cpu().numpy(),
                        pred_scores.cpu().numpy(),
                    ),
                }
            results_list.append(result)

    submission_det = pd.DataFrame(results_list, columns=['image_id', 'PredictionString'])
    submission_det_path = os.path.join(outdir, "submission_det.csv")
    submission_det.to_csv(submission_det_path, index=False)

    return submission_det_path
181
+
182
def process_image(image, image_id, model_path='path_to_your_model.pth'):
    """Gradio callback: save the uploaded image and run detection on it.

    Args:
        image: Uploaded image object exposing ``save(path)`` (the interface in
            this file requests a PIL image). ``None`` when nothing was uploaded.
        image_id: User-supplied identifier. It becomes part of file and
            directory names, so it must be a plain name without path
            separators.
        model_path: Path of the model weights passed on to ``prediction``.

    Returns:
        Path of the CSV file produced by ``prediction``.

    Raises:
        ValueError: If no image was supplied or ``image_id`` is empty or not a
            plain file name.
    """
    if image is None:
        raise ValueError("Please upload an image.")

    image_id = "" if image_id is None else str(image_id).strip()
    # ``image_id`` is untrusted user input that ends up in file paths; refuse
    # anything that could escape the working directory.
    has_separator = any(ch in image_id for ch in ("/", "\\", "\0"))
    if not image_id or image_id in {".", ".."} or has_separator:
        raise ValueError(
            "Image ID must be a non-empty name without path separators."
        )

    local_image_path = f'./{image_id}.png'
    image.save(local_image_path)
    return prediction(image_id, local_image_path, model_path)
187
+
188
# Gradio UI: takes an image plus an image ID and returns the detection CSV
# produced by ``process_image``. The title names chest X-rays because the
# detector's classes in this file are chest X-ray findings, not ECG traces.
iface = gr.Interface(
    fn=process_image,
    inputs=[gr.Image(type="pil"), gr.Textbox(label="Image ID")],
    outputs=gr.File(label="submission_det.csv"),
    title="Chest X-ray Abnormality Detection",
    description="Upload an image and get the resulting CSV file.",
)

# Start the server only when run as a script, so importing this module
# (for example from a test) does not launch it.
if __name__ == "__main__":
    iface.launch()