microScan / utils /annotations.py
crazyscientist1's picture
initial commit
d70f24c
raw
history blame contribute delete
No virus
5.49 kB
import numpy as np
import cv2
from skimage import transform
import matplotlib.pyplot as plt
from PIL import Image
# def inference_annotations(
# outputs, detection_threshold, classes,
# colors, orig_image
# ):
# boxes = outputs[0]['boxes'].data.numpy()
# scores = outputs[0]['scores'].data.numpy()
# # Filter out boxes according to `detection_threshold`.
# boxes = boxes[scores >= detection_threshold].astype(np.int32)
# draw_boxes = boxes.copy()
# # Get all the predicited class names.
# pred_classes = [classes[i] for i in outputs[0]['labels'].cpu().numpy()]
# lw = max(round(sum(orig_image.shape) / 2 * 0.003), 2) # Line width.
# tf = max(lw - 1, 1) # Font thickness.
# # Draw the bounding boxes and write the class name on top of it.
# for j, box in enumerate(draw_boxes):
# p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
# class_name = pred_classes[j]
# color = colors[classes.index(class_name)]
# cv2.rectangle(
# orig_image,
# p1, p2,
# color=(0,0,255),
# thickness=lw,
# lineType=cv2.LINE_AA
# )
# # For filled rectangle.
# w, h = cv2.getTextSize(
# class_name,
# 0,
# fontScale=lw / 3,
# thickness=tf
# )[0] # text width, height
# outside = p1[1] - h >= 3
# p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
# cv2.rectangle(
# orig_image,
# p1,
# p2,
# color=(0,0,255),
# thickness=-1,
# lineType=cv2.LINE_AA
# )
# cv2.putText(
# orig_image,
# class_name,
# (p1[0], p1[1] - 5 if outside else p1[1] + h + 2),
# cv2.FONT_HERSHEY_SIMPLEX,
# fontScale=lw / 3.8,
# color=(255, 255, 255),
# thickness=tf,
# lineType=cv2.LINE_AA
# )
# return orig_image
def CNNpostAnnotations(
    outputs, detection_threshold, classes,
    colors, orig_image, CNN
):
    """Re-classify each detected box with a CNN and draw the CNN's label.

    Detector boxes are filtered by score. Each surviving box is cropped out
    of ``orig_image`` and classified by ``CNN``. The box and the CNN label are
    then drawn on a copy of the image; ``orig_image`` itself is not modified.

    Args:
        outputs: Detector output; ``outputs[0]['boxes']`` and
            ``outputs[0]['scores']`` are tensors supporting
            ``.data.cpu().numpy()``. Presumably torchvision-style detection
            output with boxes as ``(x1, y1, x2, y2)`` — confirm with caller.
        detection_threshold: Minimum score for a box to be kept.
        classes: Unused; kept so existing callers keep working.
        colors: Unused; kept so existing callers keep working.
        orig_image: Image array indexed ``[row, col, ...]``. Crops are passed
            to ``Image.fromarray(..., 'RGB')``, so it presumably is a 3-channel
            ``uint8`` array — confirm.
        CNN: Classifier whose ``predict`` accepts a ``(1, 200, 200, 3)`` float
            batch; the argmax of its output selects the label.

    Returns:
        ``(mod, imgCellVals)``: ``mod`` is the annotated copy of the image.
        ``imgCellVals`` is a list of ``[PIL crop, label]`` pairs, one per
        classified box. Boxes whose crop is empty after clamping to the image
        are skipped (neither drawn nor returned).
    """
    # Maps the CNN's output index to its class label.
    class_index = {0: 'Gametocyte', 1: 'RBC', 2: 'Ring', 3: 'Schizont', 4: 'Trophozoite'}
    imgCellVals = []
    mod = orig_image.copy()
    # `.cpu()` makes this work whether the detector ran on CPU or GPU.
    boxes = outputs[0]['boxes'].data.cpu().numpy()
    scores = outputs[0]['scores'].data.cpu().numpy()
    # Filter out boxes according to `detection_threshold`.
    boxes = boxes[scores >= detection_threshold].astype(np.int32)
    img_h, img_w = orig_image.shape[:2]
    lw = max(round(sum(orig_image.shape) / 2 * 0.003), 2)  # Line width.
    tf = max(lw - 1, 1)  # Font thickness.
    for box in boxes:
        x1, y1, x2, y2 = (int(v) for v in box[:4])
        # Clamp the crop to the image: a negative index would otherwise wrap
        # around to the opposite edge, and an empty crop cannot be resized.
        cx1, cx2 = max(x1, 0), min(x2, img_w)
        cy1, cy2 = max(y1, 0), min(y2, img_h)
        if cx2 <= cx1 or cy2 <= cy1:
            continue
        run = orig_image[cy1:cy2, cx1:cx2]
        imgr = Image.fromarray(run, 'RGB')
        # Scale to [0, 1] and resize to the CNN's fixed input size, as a batch of one.
        np_image = run.astype('float32') / 255
        np_image = transform.resize(np_image, (200, 200, 3))
        np_image = np.expand_dims(np_image, axis=0)
        classText = class_index[np.argmax(CNN.predict(np_image))]

        p1, p2 = (x1, y1), (x2, y2)
        cv2.rectangle(
            mod,
            p1, p2,
            color=(0, 0, 255),
            thickness=lw,
            lineType=cv2.LINE_AA
        )
        # Filled background rectangle for the label text.
        w, h = cv2.getTextSize(
            classText,
            cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=lw / 3,
            thickness=tf
        )[0]  # text width, height
        # Place the label above the box when there is room, otherwise inside it.
        outside = p1[1] - h >= 3
        label_corner = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
        cv2.rectangle(
            mod,
            p1,
            label_corner,
            color=(0, 0, 255),
            thickness=-1,
            lineType=cv2.LINE_AA
        )
        cv2.putText(
            mod,
            classText,
            (p1[0], p1[1] - 5 if outside else p1[1] + h + 2),
            cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=lw / 3.8,
            color=(255, 255, 255),
            thickness=tf,
            lineType=cv2.LINE_AA
        )
        imgCellVals.append([imgr, classText])
    return mod, imgCellVals
def draw_text(
    img,
    text,
    font=cv2.FONT_HERSHEY_SIMPLEX,
    pos=(0, 0),
    font_scale=1,
    font_thickness=2,
    text_color=(0, 255, 0),
    text_color_bg=(0, 0, 0),
):
    """Draw ``text`` on ``img`` over a filled background box, in place.

    ``pos`` is the top-left corner of the text. The background box extends
    5 px beyond the text extent on every side. The drawing is done with
    ``cv2.rectangle`` / ``cv2.putText`` directly on ``img``.

    Returns:
        The same ``img`` object, for call chaining.
    """
    pad_x, pad_y = 5, 5
    left, top = pos
    (text_w, text_h), _baseline = cv2.getTextSize(text, font, font_scale, font_thickness)

    # Background box: padded outward from the text's top-left / bottom-right.
    bg_top_left = (left - pad_x, top - pad_y)
    bg_bottom_right = (left + text_w + pad_x, top + text_h + pad_y)
    cv2.rectangle(img, bg_top_left, bg_bottom_right, text_color_bg, -1)

    # putText anchors at the bottom-left of the glyphs, so shift down by the text height.
    text_origin = (left, int(top + text_h + font_scale - 1))
    cv2.putText(
        img,
        text,
        text_origin,
        font,
        font_scale,
        text_color,
        font_thickness,
        cv2.LINE_AA,
    )
    return img
def annotate_fps(orig_image, fps_text):
    """Overlay an ``FPS: <value>`` label near the top-left corner of the image.

    ``fps_text`` is formatted with ``:0.1f``, so despite its name it must be
    a number, not a string. The image is modified in place and also returned.
    """
    label = f"FPS: {fps_text:0.1f}"
    draw_text(
        orig_image,
        label,
        pos=(20, 20),
        font_scale=1.0,
        font_thickness=2,
        text_color_bg=(255, 255, 255),
        text_color=(204, 85, 17),
    )
    return orig_image