Spaces:
Runtime error
Runtime error
import datetime | |
import os | |
import torch | |
import matplotlib | |
matplotlib.use('Agg') | |
import scipy.signal | |
from matplotlib import pyplot as plt | |
from torch.utils.tensorboard import SummaryWriter | |
import shutil | |
import numpy as np | |
from PIL import Image | |
from tqdm import tqdm | |
from .utils import cvtColor, preprocess_input, resize_image | |
from .utils_bbox import DecodeBox | |
from .utils_map import get_coco_map, get_map | |
class LossHistory:
    """Record per-epoch train/val losses to text files, TensorBoard and a PNG plot.

    Files written into ``log_dir``:
      * ``epoch_loss.txt`` / ``epoch_val_loss.txt`` - one value appended per epoch.
      * ``epoch_loss.png`` - train/val loss curves, plus a Savitzky-Golay smoothed
        train curve once enough epochs have been recorded.
    """

    # Savitzky-Golay smoothing of the train-loss curve: a short window while few
    # epochs exist, a longer one from _SMOOTH_SWITCH_AT points onward.
    _SMOOTH_POLYORDER = 3
    _SMOOTH_SHORT_WINDOW = 5
    _SMOOTH_LONG_WINDOW = 15
    _SMOOTH_SWITCH_AT = 25

    def __init__(self, log_dir, model, input_shape):
        """Create the log directory and TensorBoard writer.

        Args:
            log_dir: directory for all log artifacts; created if missing.
            model: torch module whose graph is logged to TensorBoard (best effort).
            input_shape: (height, width) used to build a random dummy input of
                shape (2, 3, height, width) for graph tracing.
        """
        self.log_dir = log_dir
        self.losses = []
        self.val_loss = []

        # exist_ok: re-using an existing log dir must not abort training.
        os.makedirs(self.log_dir, exist_ok=True)
        self.writer = SummaryWriter(self.log_dir)
        try:
            dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
            self.writer.add_graph(model, dummy_input)
        except Exception as e:
            # Graph logging is optional (tracing can fail, e.g. when the model is
            # not on the CPU where dummy_input lives). Report and keep training.
            # `Exception` rather than a bare except so Ctrl-C still interrupts.
            print("LossHistory: skipping TensorBoard graph (%s)" % e)

    def append_loss(self, epoch, loss, val_loss):
        """Record one epoch's losses and refresh every log artifact.

        Args:
            epoch: epoch index used as the TensorBoard step.
            loss: training loss for the epoch.
            val_loss: validation loss for the epoch.
        """
        # The directory may have been removed while training was running.
        os.makedirs(self.log_dir, exist_ok=True)

        self.losses.append(loss)
        self.val_loss.append(val_loss)

        for filename, value in (("epoch_loss.txt", loss), ("epoch_val_loss.txt", val_loss)):
            with open(os.path.join(self.log_dir, filename), 'a') as f:
                f.write(str(value))
                f.write("\n")

        self.writer.add_scalar('loss', loss, epoch)
        self.writer.add_scalar('val_loss', val_loss, epoch)
        self.loss_plot()

    @classmethod
    def _smooth_window(cls, n_points):
        """Return the Savitzky-Golay window length for ``n_points`` samples.

        Returns None when there are too few points to smooth. With the default
        mode ('interp'), ``savgol_filter`` requires window_length <= len(x).
        """
        if n_points < cls._SMOOTH_SWITCH_AT:
            window = cls._SMOOTH_SHORT_WINDOW
        else:
            window = cls._SMOOTH_LONG_WINDOW
        if n_points < window:
            return None
        return window

    def loss_plot(self):
        """Redraw ``epoch_loss.png`` from all losses recorded so far."""
        iters = range(len(self.losses))

        plt.figure()
        plt.plot(iters, self.losses, 'red', linewidth=2, label='train loss')
        plt.plot(iters, self.val_loss, 'coral', linewidth=2, label='val loss')

        window = self._smooth_window(len(self.losses))
        if window is not None:
            try:
                smooth = scipy.signal.savgol_filter(self.losses, window, self._SMOOTH_POLYORDER)
            except (ValueError, TypeError) as e:
                # The smoothed curve is cosmetic; never let it abort training.
                print("LossHistory: skipping smoothed loss curve (%s)" % e)
            else:
                plt.plot(iters, smooth, 'green', linestyle='--', linewidth=2, label='smooth train loss')

        plt.grid(True)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend(loc="upper right")

        plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))

        plt.cla()
        plt.close("all")
class EvalCallback:
    """Periodically compute validation mAP for a detector and log it.

    On epochs selected by ``period`` (and only when ``eval_flag`` is set),
    ``on_epoch_end`` does the following:
      * writes per-image detection and ground-truth txt files for ``val_lines``
        into the scratch dir ``map_out_path``;
      * scores them with ``get_coco_map``, falling back to ``get_map``;
      * appends the score to ``log_dir/epoch_map.txt``;
      * redraws ``log_dir/epoch_map.png``;
      * deletes the scratch dir.

    Each entry of ``val_lines`` is whitespace-separated: an image path followed
    by zero or more ``left,top,right,bottom,class_id`` integer groups.
    """

    def __init__(self, net, input_shape, anchors, anchors_mask, class_names, num_classes, val_lines, log_dir, cuda,
                 map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5, letterbox_image=True,
                 MINOVERLAP=0.5, eval_flag=True, period=1):
        """Store the evaluation settings and seed the mAP log with a 0 entry.

        Args:
            net: model used for inference (replaced by ``on_epoch_end``'s ``model_eval``).
            input_shape: (height, width) the network input is resized to.
            anchors, anchors_mask: passed through to ``DecodeBox``.
            class_names: list mapping class index -> name.
            num_classes: number of classes.
            val_lines: annotation lines to evaluate (format in the class docstring).
            log_dir: existing directory receiving ``epoch_map.txt`` / ``epoch_map.png``.
            cuda: move input images to CUDA before inference when true.
            map_out_path: scratch directory; it is deleted after each evaluation.
            max_boxes: keep at most this many highest-scoring detections per image.
            confidence, nms_iou: ``conf_thres`` / ``nms_thres`` for ``non_max_suppression``.
            letterbox_image: whether resizing preserves aspect ratio via letterboxing.
            MINOVERLAP: IoU threshold for the ``get_map`` fallback and the plot label.
            eval_flag: master switch for evaluation.
            period: evaluate on epochs where ``epoch % period == 0``.
        """
        super().__init__()

        self.net = net
        self.input_shape = input_shape
        self.anchors = anchors
        self.anchors_mask = anchors_mask
        self.class_names = class_names
        self.num_classes = num_classes
        self.val_lines = val_lines
        self.log_dir = log_dir
        self.cuda = cuda
        self.map_out_path = map_out_path
        self.max_boxes = max_boxes
        self.confidence = confidence
        self.nms_iou = nms_iou
        self.letterbox_image = letterbox_image
        self.MINOVERLAP = MINOVERLAP
        self.eval_flag = eval_flag
        self.period = period

        self.bbox_util = DecodeBox(self.anchors, self.num_classes, (self.input_shape[0], self.input_shape[1]), self.anchors_mask)

        # The curve starts at (epoch 0, mAP 0); the txt log mirrors that first point.
        self.maps = [0]
        self.epoches = [0]
        if self.eval_flag:
            with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
                f.write(str(0))
                f.write("\n")

    @staticmethod
    def _parse_annotation(annotation_line):
        """Split an annotation line into ``(image_path, gt_boxes)``.

        ``gt_boxes`` is a list of int tuples, one per whitespace-separated
        comma group after the path, e.g. ``(left, top, right, bottom, class_id)``.
        """
        fields = annotation_line.split()
        gt_boxes = [tuple(map(int, box.split(','))) for box in fields[1:]]
        return fields[0], gt_boxes

    @staticmethod
    def _image_id(image_path):
        """Return the file name of ``image_path`` without its final extension."""
        # splitext strips only the last suffix ("img.1.jpg" -> "img.1").
        # split('.')[0] would map "img.1.jpg" and "img.2.jpg" to the same id,
        # so their result files would overwrite each other.
        return os.path.splitext(os.path.basename(image_path))[0]

    def get_map_txt(self, image_id, image, class_names, map_out_path):
        """Run the detector on one image and write its detections for scoring.

        Writes ``map_out_path/detection-results/<image_id>.txt``. The file
        holds one ``"<class> <score> <left> <top> <right> <bottom>"`` line per
        kept box, with at most ``self.max_boxes`` boxes, highest score first.
        Boxes whose class is not in ``class_names`` are skipped. The file is
        created even when nothing is detected (as the original did); presumably
        the mAP tools expect one detection file per ground-truth file, which
        should be confirmed against utils_map.
        """
        # Opening via `with` guarantees the handle is closed on every path,
        # including the early return below for images without detections.
        with open(os.path.join(map_out_path, "detection-results/" + image_id + ".txt"), "w", encoding='utf-8') as f:
            image_shape = np.array(np.shape(image)[0:2])
            image = cvtColor(image)
            image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
            # HWC float32 -> preprocess -> CHW -> add batch dim.
            image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

            with torch.no_grad():
                images = torch.from_numpy(image_data)
                if self.cuda:
                    images = images.cuda()
                outputs = self.net(images)
                outputs = self.bbox_util.decode_box(outputs)
                results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape,
                                                             image_shape, self.letterbox_image,
                                                             conf_thres=self.confidence, nms_thres=self.nms_iou)
                if results[0] is None:
                    return

                # Column 6 is the class index; the score is column 4 * column 5;
                # columns 0-3 are the box, unpacked below as (top, left, bottom, right).
                top_label = np.array(results[0][:, 6], dtype='int32')
                top_conf = results[0][:, 4] * results[0][:, 5]
                top_boxes = results[0][:, :4]

            keep = np.argsort(top_conf)[::-1][:self.max_boxes]
            for c, conf, box in zip(top_label[keep], top_conf[keep], top_boxes[keep]):
                predicted_class = self.class_names[int(c)]
                if predicted_class not in class_names:
                    continue
                top, left, bottom, right = box
                score = str(conf)
                f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)), str(int(bottom))))

    def _plot_map(self):
        """Redraw ``log_dir/epoch_map.png`` from ``self.epoches`` / ``self.maps``."""
        plt.figure()
        plt.plot(self.epoches, self.maps, 'red', linewidth=2, label='train map')

        plt.grid(True)
        plt.xlabel('Epoch')
        plt.ylabel('Map %s' % str(self.MINOVERLAP))
        plt.title('A Map Curve')
        plt.legend(loc="upper right")

        plt.savefig(os.path.join(self.log_dir, "epoch_map.png"))
        plt.cla()
        plt.close("all")

    def on_epoch_end(self, epoch, model_eval):
        """Evaluate mAP on ``val_lines`` with ``model_eval`` if this epoch is due.

        Args:
            epoch: current epoch index. Evaluation runs when ``eval_flag`` is set
                and ``epoch % period == 0``.
            model_eval: model used for inference; stored as ``self.net``.
        """
        if not (epoch % self.period == 0 and self.eval_flag):
            return
        self.net = model_eval

        # map_out_path is scratch space (it is deleted once scoring ends). Clear
        # leftovers from an interrupted earlier run so stale files are not scored.
        shutil.rmtree(self.map_out_path, ignore_errors=True)
        os.makedirs(os.path.join(self.map_out_path, "ground-truth"), exist_ok=True)
        os.makedirs(os.path.join(self.map_out_path, "detection-results"), exist_ok=True)

        print("Get map.")
        for annotation_line in tqdm(self.val_lines):
            image_path, gt_boxes = self._parse_annotation(annotation_line)
            image_id = self._image_id(image_path)

            # `with` closes the underlying image file once inference is done.
            with Image.open(image_path) as image:
                self.get_map_txt(image_id, image, self.class_names, self.map_out_path)

            with open(os.path.join(self.map_out_path, "ground-truth/" + image_id + ".txt"), "w") as new_f:
                for left, top, right, bottom, obj in gt_boxes:
                    obj_name = self.class_names[obj]
                    new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))

        print("Calculate Map.")
        try:
            temp_map = get_coco_map(class_names=self.class_names, path=self.map_out_path)[1]
        except Exception:
            # COCO-style scoring is optional; fall back to get_map when it fails.
            # `Exception` rather than a bare except so Ctrl-C aborts instead of
            # silently switching metrics.
            temp_map = get_map(self.MINOVERLAP, False, path=self.map_out_path)

        self.maps.append(temp_map)
        self.epoches.append(epoch)

        with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
            f.write(str(temp_map))
            f.write("\n")

        self._plot_map()

        print("Get map done.")
        shutil.rmtree(self.map_out_path)