# microScan / utils/annotations.py
# Author: crazyscientist1 — commit d70f24c ("initial commit"), 5.49 kB.
import numpy as np
import cv2
from skimage import transform
import matplotlib.pyplot as plt
from PIL import Image
# def inference_annotations(
# outputs, detection_threshold, classes,
# colors, orig_image
# ):
# boxes = outputs[0]['boxes'].data.numpy()
# scores = outputs[0]['scores'].data.numpy()
# # Filter out boxes according to `detection_threshold`.
# boxes = boxes[scores >= detection_threshold].astype(np.int32)
# draw_boxes = boxes.copy()
# # Get all the predicited class names.
# pred_classes = [classes[i] for i in outputs[0]['labels'].cpu().numpy()]
# lw = max(round(sum(orig_image.shape) / 2 * 0.003), 2) # Line width.
# tf = max(lw - 1, 1) # Font thickness.
# # Draw the bounding boxes and write the class name on top of it.
# for j, box in enumerate(draw_boxes):
# p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
# class_name = pred_classes[j]
# color = colors[classes.index(class_name)]
# cv2.rectangle(
# orig_image,
# p1, p2,
# color=(0,0,255),
# thickness=lw,
# lineType=cv2.LINE_AA
# )
# # For filled rectangle.
# w, h = cv2.getTextSize(
# class_name,
# 0,
# fontScale=lw / 3,
# thickness=tf
# )[0] # text width, height
# outside = p1[1] - h >= 3
# p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
# cv2.rectangle(
# orig_image,
# p1,
# p2,
# color=(0,0,255),
# thickness=-1,
# lineType=cv2.LINE_AA
# )
# cv2.putText(
# orig_image,
# class_name,
# (p1[0], p1[1] - 5 if outside else p1[1] + h + 2),
# cv2.FONT_HERSHEY_SIMPLEX,
# fontScale=lw / 3.8,
# color=(255, 255, 255),
# thickness=tf,
# lineType=cv2.LINE_AA
# )
# return orig_image
def CNNpostAnnotations(
    outputs, detection_threshold, classes,
    colors, orig_image, CNN
):
    """Re-classify every detected box with a CNN and draw the CNN's label.

    Boxes whose detector score is >= ``detection_threshold`` are cropped from
    ``orig_image``, scaled to [0, 1], resized to 200x200x3 and passed to
    ``CNN.predict``. The arg-max output index is mapped to a cell class name,
    and that name is drawn together with the box on a copy of the image.

    Args:
        outputs: Detector output. ``outputs[0]`` must provide ``'boxes'`` and
            ``'scores'`` tensors (torch-style: ``.detach().cpu().numpy()``
            is called on them).
        detection_threshold: Minimum score for a box to be kept.
        classes: Detector class names. Unused: the drawn label comes from
            the CNN. Kept for API compatibility.
        colors: Unused; kept for API compatibility.
        orig_image: Image array the boxes refer to; it is not modified.
            Crops are wrapped with ``Image.fromarray(..., 'RGB')``, so this
            presumably is a uint8 (H, W, 3) RGB array — confirm with callers.
        CNN: Model exposing ``predict(batch)``. It receives a
            ``(1, 200, 200, 3)`` float array per box.

    Returns:
        ``(annotated_copy, cells)``, where ``cells`` holds one
        ``[PIL.Image crop, class name]`` pair per drawn box. Boxes that are
        empty after clipping to the image are skipped.
    """
    # Output index of ``CNN.predict`` -> cell class name (loop invariant).
    class_index = {0: 'Gametocyte', 1: 'RBC', 2: 'Ring', 3: 'Schizont', 4: 'Trophozoite'}
    imgCellVals = []
    mod = orig_image.copy()
    # detach().cpu() works for CPU and GPU tensors alike; ``.data.numpy()``
    # alone fails on CUDA tensors.
    boxes = outputs[0]['boxes'].detach().cpu().numpy()
    scores = outputs[0]['scores'].detach().cpu().numpy()
    # Filter out boxes according to `detection_threshold`.
    boxes = boxes[scores >= detection_threshold].astype(np.int32)
    img_h, img_w = orig_image.shape[:2]
    lw = max(round(sum(orig_image.shape) / 2 * 0.003), 2)  # Line width.
    tf = max(lw - 1, 1)  # Font thickness.
    for box in boxes:
        x1, y1, x2, y2 = (int(v) for v in box[:4])
        # Clip to the image: negative indices would otherwise wrap around
        # when slicing and yield a wrong crop.
        x1, y1 = max(x1, 0), max(y1, 0)
        x2, y2 = min(x2, img_w), min(y2, img_h)
        if x2 <= x1 or y2 <= y1:
            # Degenerate box (e.g. sub-pixel after int truncation): there is
            # nothing to classify, and resizing an empty crop would raise.
            continue
        crop = orig_image[y1:y2, x1:x2]
        imgr = Image.fromarray(crop, 'RGB')
        np_image = transform.resize(crop.astype('float32') / 255, (200, 200, 3))
        np_image = np.expand_dims(np_image, axis=0)  # batch of one
        classText = class_index[int(np.argmax(CNN.predict(np_image)))]

        p1, p2 = (x1, y1), (x2, y2)
        cv2.rectangle(
            mod,
            p1, p2,
            color=(0, 0, 255),
            thickness=lw,
            lineType=cv2.LINE_AA
        )
        # Filled label background: above the box when there is room,
        # otherwise just below the box's top edge.
        w, h = cv2.getTextSize(
            classText,
            0,
            fontScale=lw / 3,
            thickness=tf
        )[0]  # text width, height
        outside = p1[1] - h >= 3
        label_end = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
        cv2.rectangle(
            mod,
            p1,
            label_end,
            color=(0, 0, 255),
            thickness=-1,
            lineType=cv2.LINE_AA
        )
        cv2.putText(
            mod,
            classText,
            (p1[0], p1[1] - 5 if outside else p1[1] + h + 2),
            cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=lw / 3.8,
            color=(255, 255, 255),
            thickness=tf,
            lineType=cv2.LINE_AA
        )
        imgCellVals.append([imgr, classText])
    return mod, imgCellVals
def draw_text(
    img,
    text,
    font=cv2.FONT_HERSHEY_SIMPLEX,
    pos=(0, 0),
    font_scale=1,
    font_thickness=2,
    text_color=(0, 255, 0),
    text_color_bg=(0, 0, 0),
):
    """Draw ``text`` onto ``img`` in place, over a solid background box.

    ``pos`` is the top-left corner of the measured text extent. The filled
    background rectangle extends 5 px beyond that extent on every side.
    Returns ``img`` itself (the same array, modified in place).
    """
    pad_x, pad_y = 5, 5
    left, top = pos
    (width, height), _baseline = cv2.getTextSize(text, font, font_scale, font_thickness)

    # Background first, so the text is painted on top of it.
    cv2.rectangle(
        img,
        (left - pad_x, top - pad_y),
        (left + width + pad_x, top + height + pad_y),
        text_color_bg,
        -1,
    )

    # putText's origin is the bottom-left of the string, so shift down by
    # the text height (plus a small font_scale-dependent offset).
    origin = (left, int(top + height + font_scale - 1))
    cv2.putText(img, text, origin, font, font_scale, text_color, font_thickness, cv2.LINE_AA)
    return img
def annotate_fps(orig_image, fps_text):
    """Overlay an ``FPS: <value>`` label near the top-left of ``orig_image``.

    Despite its name, ``fps_text`` must support ``:0.1f`` formatting (e.g. a
    float). The label is drawn by ``draw_text``, which modifies the image in
    place; the same image object is returned.
    """
    label = f"FPS: {fps_text:0.1f}"
    draw_text(
        orig_image,
        label,
        pos=(20, 20),
        font_scale=1.0,
        font_thickness=2,
        text_color=(204, 85, 17),
        text_color_bg=(255, 255, 255),
    )
    return orig_image