Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
from pathlib import Path
|
4 |
+
import pandas as pd
|
5 |
+
import gradio as gr
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
from PIL import Image as PILImage
|
9 |
+
from detectron2 import model_zoo
|
10 |
+
from detectron2.config import get_cfg
|
11 |
+
from detectron2.engine import DefaultPredictor
|
12 |
+
from detectron2.data import DatasetCatalog, MetadataCatalog
|
13 |
+
from detectron2.utils.visualizer import ColorMode, Visualizer
|
14 |
+
from tqdm import tqdm
|
15 |
+
import uuid
|
16 |
+
import cv2
|
17 |
+
import pickle
|
18 |
+
from math import ceil
|
19 |
+
from typing import List, Dict, Any
|
20 |
+
from numpy import ndarray
|
21 |
+
|
22 |
+
def get_vinbigdata_dicts_test(imgdir: Path, test_meta: pd.DataFrame, use_cache: bool = True, debug: bool = False):
    """Build, or load from an on-disk cache, the dataset dicts for the test images.

    Each returned dict has the keys ``file_name``, ``image_id``, ``height`` and
    ``width``. The height/width stored in every record are the pixel
    dimensions of the *first* image on disk, not the per-row values found in
    ``test_meta`` (presumably because all images share one size -- confirm
    against the data preparation step).

    Args:
        imgdir: Directory holding ``<image_id>.png`` files. A plain string is
            accepted and converted to ``Path``.
        test_meta: Frame with an ``image_id`` column; every row must hold
            exactly three values (image id plus two dimensions).
        use_cache: When true and the cache file exists, return its content
            without looking at ``imgdir`` or ``test_meta`` at all.
        debug: Limit processing to the first 500 rows; also selects a separate
            cache file.

    Returns:
        A list of record dicts (empty when ``test_meta`` has no rows).

    Raises:
        FileNotFoundError: If the first image cannot be read from ``imgdir``.

    Note:
        The cache file name depends only on ``debug``. A cache written for a
        different ``imgdir``/``test_meta`` is therefore returned unchanged
        when ``use_cache`` is true; pass ``use_cache=False`` when inputs vary
        between calls.
    """
    # Callers in this file pass a plain string; the "/" operator needs a Path.
    imgdir = Path(imgdir)
    debug_str = f"_debug{int(debug)}"
    cache_path = Path(".") / f"dataset_dicts_cache_test{debug_str}.pkl"

    if use_cache and cache_path.exists():
        # NOTE(security): pickle.load can execute arbitrary code. This is only
        # acceptable because the file is one this function wrote itself; never
        # point it at a file from an untrusted source.
        with open(cache_path, mode="rb") as f:
            return pickle.load(f)

    if debug:
        test_meta = test_meta.iloc[:500]  # For debug

    dataset_dicts = []
    if len(test_meta) > 0:
        # Load 1 image to get image size. Positional access keeps this working
        # when the frame does not carry a default 0..n-1 index.
        first_image_id = test_meta["image_id"].iloc[0]
        first_image_path = imgdir / f"{first_image_id}.png"
        image = cv2.imread(str(first_image_path))
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read image: {first_image_path}")
        resized_height, resized_width = image.shape[:2]

        for _, test_meta_row in tqdm(test_meta.iterrows(), total=len(test_meta)):
            # The per-row dimensions are deliberately not used; see docstring.
            image_id, _height, _width = test_meta_row.values
            dataset_dicts.append(
                {
                    "file_name": str(imgdir / f"{image_id}.png"),
                    "image_id": image_id,
                    "height": resized_height,
                    "width": resized_width,
                }
            )

    with open(cache_path, mode="wb") as f:
        pickle.dump(dataset_dicts, f)
    return dataset_dicts
|
52 |
+
|
53 |
+
def format_pred(labels: ndarray, boxes: ndarray, scores: ndarray) -> str:
    """Serialise detections into one space-separated prediction string.

    Every detection contributes six tokens in the order
    ``label score xmin ymin xmax ymax``. Box coordinates are cast to
    ``np.int64`` (fractional parts are dropped); labels and scores are
    written with their default string formatting.

    Args:
        labels: Class id per detection.
        boxes: One ``(xmin, ymin, xmax, ymax)`` row per detection.
        scores: Confidence per detection.

    Returns:
        The joined string, or ``""`` when there are no detections.
    """

    def _render(class_id, confidence, box) -> str:
        x0, y0, x1, y1 = box.astype(np.int64)
        return f"{class_id} {confidence} {x0} {y0} {x1} {y1}"

    return " ".join(
        _render(class_id, confidence, box)
        for class_id, confidence, box in zip(labels, scores, boxes)
    )
|
59 |
+
|
60 |
+
def predict_batch(predictor: DefaultPredictor, im_list: List[ndarray]) -> List:
    """Run the predictor's underlying model on several images in one call.

    Each image is converted to a float32 channel-first tensor and paired with
    its own height and width. No resizing transform is applied here, so the
    model receives the images at the resolution they were passed in.

    Args:
        predictor: Supplies ``input_format`` and the callable ``model``.
        im_list: Images as height x width x channel arrays (assumed to be BGR,
            as produced by ``cv2.imread`` -- confirm against the caller).

    Returns:
        Whatever ``predictor.model`` returns for the batch.
    """

    def _as_model_input(img: ndarray) -> dict:
        # Reverse the channel axis when the model expects RGB input.
        if predictor.input_format == "RGB":
            img = img[:, :, ::-1]
        rows, cols = img.shape[:2]
        chw = torch.as_tensor(img.astype("float32").transpose(2, 0, 1))
        return {"image": chw, "height": rows, "width": cols}

    # Inference only: keep autograd off for the whole batch.
    with torch.no_grad():
        batch_inputs = [_as_model_input(img) for img in im_list]
        return predictor.model(batch_inputs)
|
72 |
+
|
73 |
+
def csv_create(new_image_path, image_id):
    """Write the two single-row CSV files that describe one uploaded image.

    Both files are created next to ``new_image_path``:

    * ``sample_submission.csv`` -- columns ``image_id`` and
      ``PredictionString`` with the placeholder ``'14 1 0 0 1 1'``.
    * ``test_meta.csv`` -- columns ``image_id``, ``dim0`` (image height) and
      ``dim1`` (image width).

    Args:
        new_image_path: Path of the image file to describe.
        image_id: Identifier written to the ``image_id`` column of both files.

    Returns:
        Tuple ``(sample_submission_path, test_meta_path)``.
    """
    # PIL reports size as (width, height). The context manager releases the
    # file handle as soon as the size has been read.
    with PILImage.open(new_image_path) as image:
        width, height = image.size
    directory = os.path.dirname(new_image_path)

    sample_submission_df = pd.DataFrame(
        {
            'image_id': [image_id],
            'PredictionString': ['14 1 0 0 1 1'],
        }
    )
    sample_submission_path = os.path.join(directory, 'sample_submission.csv')
    sample_submission_df.to_csv(sample_submission_path, index=False)

    # prediction() divides dim0 by im.shape[0] (rows) and dim1 by im.shape[1]
    # (columns), so dim0 must be the height and dim1 the width. The previous
    # code stored them the other way round, which distorted boxes on
    # non-square images.
    test_meta_df = pd.DataFrame(
        {
            'image_id': [image_id],
            'dim0': [height],
            'dim1': [width],
        }
    )
    test_meta_path = os.path.join(directory, 'test_meta.csv')
    test_meta_df.to_csv(test_meta_path, index=False)

    return sample_submission_path, test_meta_path
|
96 |
+
|
97 |
+
def prediction(image_id_main, local_image_path, model_path):
    """Run the detector on one image and write the detections to a CSV file.

    Steps: copy the image into ``processed_images_<image_id_main>``, create
    its metadata CSV, build a Faster R-CNN predictor from the model-zoo config
    with the weights at ``model_path``, run inference in batches, and write
    ``result_images/submission_det.csv``.

    Args:
        image_id_main: Identifier of the image; the copied file is expected to
            be named ``<image_id_main>.png``.
        local_image_path: Path of the image file to process.
        model_path: Path (or URL) of the trained weights to load.

    Returns:
        Path of the written ``submission_det.csv``.

    Raises:
        FileNotFoundError: If an image listed in the dataset dicts cannot be
            read.
    """
    thing_classes = [
        "Aortic enlargement", "Atelectasis", "Calcification", "Cardiomegaly", "Consolidation",
        "ILD", "Infiltration", "Lung Opacity", "Nodule/Mass", "Other lesion", "Pleural effusion",
        "Pleural thickening", "Pneumothorax", "Pulmonary fibrosis"
    ]

    outdir = 'result_images'
    os.makedirs(outdir, exist_ok=True)

    imgdir = f'processed_images_{image_id_main}'
    os.makedirs(imgdir, exist_ok=True)
    shutil.copy(local_image_path, imgdir)
    new_image_path = os.path.join(imgdir, os.path.basename(local_image_path))

    # csv_create also writes a sample submission file; only the metadata file
    # is needed here. Read it once, and keep image_id as text so an id such as
    # "007" still matches the file name "007.png".
    _, test_meta_path = csv_create(new_image_path, image_id_main)
    test_meta_df = pd.read_csv(test_meta_path, dtype={"image_id": str})

    cfg = get_cfg()
    cfg.OUTPUT_DIR = outdir
    cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
    cfg.DATASETS.TEST = ()
    cfg.DATALOADER.NUM_WORKERS = 2
    cfg.SOLVER.IMS_PER_BATCH = 2
    cfg.SOLVER.BASE_LR = 0.001
    cfg.SOLVER.MAX_ITER = 30000
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(thing_classes)
    cfg.MODEL.WEIGHTS = model_path
    # Threshold 0.0 keeps every detection regardless of score.
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.0
    predictor = DefaultPredictor(cfg)

    def load_dataset_dicts():
        # use_cache=False: the cache file name does not depend on the image
        # directory, so a cached list from an earlier request would otherwise
        # be returned for this one. Path(...) because the loader joins paths
        # with the "/" operator.
        return get_vinbigdata_dicts_test(Path(imgdir), test_meta_df, use_cache=False, debug=False)

    unique_id = f"bigdata2_{uuid.uuid4().hex[:8]}"
    DatasetCatalog.register(unique_id, load_dataset_dicts)
    MetadataCatalog.get(unique_id).set(thing_classes=thing_classes)
    dataset_dicts = load_dataset_dicts()

    results_list = []
    batch_size = 4

    for start in tqdm(range(0, len(dataset_dicts), batch_size)):
        dataset_dicts_batch = dataset_dicts[start:start + batch_size]

        im_list = []
        for d in dataset_dicts_batch:
            im = cv2.imread(d["file_name"])
            if im is None:
                # cv2.imread returns None instead of raising on failure.
                raise FileNotFoundError(f"Could not read image: {d['file_name']}")
            im_list.append(im)

        outputs_list = predict_batch(predictor, im_list)

        for offset, (im, outputs) in enumerate(zip(im_list, outputs_list)):
            # Dataset dicts are built row by row from test_meta_df, so the
            # position in dataset_dicts is also the metadata row. The previous
            # code used the batch counter here, which picks the wrong row as
            # soon as a batch holds more than one image.
            row = start + offset
            image_id = test_meta_df["image_id"].iloc[row]
            dim0 = test_meta_df["dim0"].iloc[row]
            dim1 = test_meta_df["dim1"].iloc[row]

            instances = outputs["instances"]
            if len(instances) == 0:
                # Placeholder row for "nothing detected" (class 14 is
                # presumably the no-finding class -- confirm).
                result = {"image_id": image_id, "PredictionString": "14 1.0 0 0 1 1"}
            else:
                fields: Dict[str, Any] = instances.get_fields()
                pred_classes = fields["pred_classes"]
                pred_scores = fields["scores"]
                pred_boxes = fields["pred_boxes"].tensor

                # Scale boxes from the processed image size to the size
                # recorded in the metadata (dim0 vs rows, dim1 vs columns).
                h_ratio = float(dim0 / im.shape[0])
                w_ratio = float(dim1 / im.shape[1])
                pred_boxes[:, [0, 2]] *= w_ratio
                pred_boxes[:, [1, 3]] *= h_ratio

                result = {
                    "image_id": image_id,
                    "PredictionString": format_pred(
                        pred_classes.cpu().numpy(),
                        pred_boxes.cpu().numpy(),
                        pred_scores.cpu().numpy(),
                    ),
                }
            results_list.append(result)

    submission_det = pd.DataFrame(results_list, columns=['image_id', 'PredictionString'])
    submission_det_path = os.path.join(outdir, "submission_det.csv")
    submission_det.to_csv(submission_det_path, index=False)

    return submission_det_path
|
181 |
+
|
182 |
+
def process_image(image, image_id, model_path='path_to_your_model.pth'):
    """Save the uploaded image as ``./<image_id>.png`` and run detection on it.

    Args:
        image: Object with a ``save(path)`` method (the interface in this file
            supplies a PIL image).
        image_id: Identifier used for the saved file name and in the output.
        model_path: Weights file handed to ``prediction``. The default looks
            like a placeholder -- confirm the real path before deploying.

    Returns:
        Path of the CSV file produced by ``prediction``.

    Raises:
        ValueError: If no image was supplied, or if ``image_id`` is empty or
            contains a path separator.
    """
    if image is None:
        # Happens when the form is submitted without an upload.
        raise ValueError("No image was provided.")

    id_text = "" if image_id is None else str(image_id)
    if not id_text.strip():
        raise ValueError("Image ID must not be empty.")
    # The id comes from a free-text field and is interpolated into file and
    # directory names; refuse anything that could leave the working directory.
    if "/" in id_text or "\\" in id_text:
        raise ValueError("Image ID must not contain path separators.")

    local_image_path = f'./{image_id}.png'
    image.save(local_image_path)
    return prediction(image_id, local_image_path, model_path)
|
187 |
+
|
188 |
+
# Web front end: one image upload plus a free-text image id go in, the
# detections CSV produced by process_image comes out as a downloadable file.
# NOTE(review): the title says "ECG", while the class list in prediction()
# names findings such as "Cardiomegaly" and "Pneumothorax" -- confirm the
# intended wording.
iface = gr.Interface(
    process_image,
    [gr.Image(type="pil"), gr.Textbox(label="Image ID")],
    gr.File(label="submission_det.csv"),
    title="ECG Image Processing",
    description="Upload an image and get the resulting CSV file.",
)

# Starts the server as soon as this module is executed.
iface.launch()
|