Spaces:
Sleeping
Sleeping
import numpy as np | |
import cv2 | |
from skimage import transform | |
import matplotlib.pyplot as plt | |
from PIL import Image | |
# def inference_annotations( | |
# outputs, detection_threshold, classes, | |
# colors, orig_image | |
# ): | |
# boxes = outputs[0]['boxes'].data.numpy() | |
# scores = outputs[0]['scores'].data.numpy() | |
# # Filter out boxes according to `detection_threshold`. | |
# boxes = boxes[scores >= detection_threshold].astype(np.int32) | |
# draw_boxes = boxes.copy() | |
# # Get all the predicted class names.
# pred_classes = [classes[i] for i in outputs[0]['labels'].cpu().numpy()] | |
# lw = max(round(sum(orig_image.shape) / 2 * 0.003), 2) # Line width. | |
# tf = max(lw - 1, 1) # Font thickness. | |
# # Draw the bounding boxes and write the class name on top of it. | |
# for j, box in enumerate(draw_boxes): | |
# p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3])) | |
# class_name = pred_classes[j] | |
# color = colors[classes.index(class_name)] | |
# cv2.rectangle( | |
# orig_image, | |
# p1, p2, | |
# color=(0,0,255), | |
# thickness=lw, | |
# lineType=cv2.LINE_AA | |
# ) | |
# # For filled rectangle. | |
# w, h = cv2.getTextSize( | |
# class_name, | |
# 0, | |
# fontScale=lw / 3, | |
# thickness=tf | |
# )[0] # text width, height | |
# outside = p1[1] - h >= 3 | |
# p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3 | |
# cv2.rectangle( | |
# orig_image, | |
# p1, | |
# p2, | |
# color=(0,0,255), | |
# thickness=-1, | |
# lineType=cv2.LINE_AA | |
# ) | |
# cv2.putText( | |
# orig_image, | |
# class_name, | |
# (p1[0], p1[1] - 5 if outside else p1[1] + h + 2), | |
# cv2.FONT_HERSHEY_SIMPLEX, | |
# fontScale=lw / 3.8, | |
# color=(255, 255, 255), | |
# thickness=tf, | |
# lineType=cv2.LINE_AA | |
# ) | |
# return orig_image | |
# Maps the secondary CNN's output index (argmax position) to a class name.
_CNN_CLASS_INDEX = {0: 'Gametocyte', 1: 'RBC', 2: 'Ring', 3: 'Schizont', 4: 'Trophozoite'}


def _to_numpy(tensor):
    """Return *tensor* as a NumPy array, detached and moved to the CPU first."""
    # `.cpu()` makes this work for GPU tensors too; `.detach()` drops autograd history.
    return tensor.detach().cpu().numpy()


def _classify_crop(crop, CNN, input_size, class_names):
    """Run one image crop through *CNN* and return the predicted class name.

    The crop is scaled by 1/255 (presumably uint8 pixels in [0, 255] -- confirm
    against callers), resized to ``(input_size[0], input_size[1], 3)`` with
    skimage, given a leading batch axis, and the argmax of ``CNN.predict`` is
    looked up in *class_names*.
    """
    batch = crop.astype('float32') / 255
    batch = transform.resize(batch, (input_size[0], input_size[1], 3))
    batch = np.expand_dims(batch, axis=0)
    return class_names[int(np.argmax(CNN.predict(batch)))]


def _draw_labeled_box(img, p1, p2, text, lw, tf):
    """Draw a box from *p1* to *p2* on *img* (in place) with *text* on a filled tag at *p1*."""
    cv2.rectangle(img, p1, p2, color=(0, 0, 255), thickness=lw, lineType=cv2.LINE_AA)
    w, h = cv2.getTextSize(text, 0, fontScale=lw / 3, thickness=tf)[0]  # text width, height
    # Place the tag above the box when there are at least 3 px of room, else inside it.
    outside = p1[1] - h >= 3
    tag_end = (p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3)
    cv2.rectangle(img, p1, tag_end, color=(0, 0, 255), thickness=-1, lineType=cv2.LINE_AA)
    cv2.putText(
        img,
        text,
        (p1[0], p1[1] - 5 if outside else p1[1] + h + 2),
        cv2.FONT_HERSHEY_SIMPLEX,
        fontScale=lw / 3.8,
        color=(255, 255, 255),
        thickness=tf,
        lineType=cv2.LINE_AA,
    )


def CNNpostAnnotations(
    outputs, detection_threshold, classes,
    colors, orig_image, CNN, *,
    cnn_input_size=(200, 200), cnn_class_names=None,
):
    """Label detector boxes using a secondary CNN classifier and draw them.

    Every detection in ``outputs[0]`` whose score is ``>= detection_threshold``
    is cropped from ``orig_image`` and classified by ``CNN``. The box and the
    CNN's class name are drawn on a copy of the image.

    Boxes are clipped to the image bounds before cropping. A box that is
    empty after clipping is skipped: it is neither classified nor drawn,
    because an empty crop cannot be resized or converted to a PIL image.

    Args:
        outputs: Detector output. ``outputs[0]['boxes']`` and
            ``outputs[0]['scores']`` must be tensors with ``.detach()``,
            ``.cpu()`` and ``.numpy()`` (presumably torch). Each box is
            indexed as ``(x1, y1, x2, y2)``.
        detection_threshold: Minimum score for a detection to be kept.
        classes: Detector class names. Unused here; kept for call
            compatibility.
        colors: Per-class colours. Unused here; kept for call compatibility.
        orig_image: Image array indexed ``[y, x]`` with 3 channels. It is not
            modified.
        CNN: Object whose ``predict`` accepts a ``(1, H, W, 3)`` float batch
            and returns per-class scores (presumably a Keras model -- confirm).
        cnn_input_size: ``(height, width)`` each crop is resized to before
            classification.
        cnn_class_names: Mapping from CNN output index to class name. Defaults
            to ``_CNN_CLASS_INDEX``.

    Returns:
        ``(annotated, cells)``. ``annotated`` is the drawn-on copy of
        ``orig_image``. ``cells`` is a list of ``[PIL.Image crop, class_name]``
        pairs, one per classified box.
    """
    if cnn_class_names is None:
        cnn_class_names = _CNN_CLASS_INDEX

    annotated = orig_image.copy()
    cells = []

    detections = outputs[0]
    boxes = _to_numpy(detections['boxes'])
    scores = _to_numpy(detections['scores'])
    boxes = boxes[scores >= detection_threshold].astype(np.int32)

    img_h, img_w = orig_image.shape[:2]
    lw = max(round(sum(orig_image.shape) / 2 * 0.003), 2)  # Line width.
    tf = max(lw - 1, 1)  # Font thickness.

    for box in boxes:
        x1, y1, x2, y2 = (int(v) for v in box)
        # Clip only for cropping: negative indices would wrap around in NumPy slicing.
        cx1, cy1 = max(x1, 0), max(y1, 0)
        cx2, cy2 = min(x2, img_w), min(y2, img_h)
        if cx2 <= cx1 or cy2 <= cy1:
            continue

        crop = orig_image[cy1:cy2, cx1:cx2]
        crop_img = Image.fromarray(crop, 'RGB')
        class_text = _classify_crop(crop, CNN, cnn_input_size, cnn_class_names)

        # Drawing uses the unclipped corners; cv2 clips drawing to the image itself.
        _draw_labeled_box(annotated, (x1, y1), (x2, y2), class_text, lw, tf)
        cells.append([crop_img, class_text])

    return annotated, cells
def draw_text(
    img,
    text,
    font=cv2.FONT_HERSHEY_SIMPLEX,
    pos=(0, 0),
    font_scale=1,
    font_thickness=2,
    text_color=(0, 255, 0),
    text_color_bg=(0, 0, 0),
):
    """Write *text* onto *img* over a solid background rectangle.

    The background spans the text's measured size around *pos* (the top-left
    corner of the text), padded by 5 px on every side, and is filled with
    *text_color_bg*. The text itself is then drawn anti-aliased in
    *text_color*. *img* is modified in place and also returned.
    """
    pad_x, pad_y = 5, 5
    left, top = pos
    (text_w, text_h), _baseline = cv2.getTextSize(text, font, font_scale, font_thickness)

    bg_top_left = (left - pad_x, top - pad_y)
    bg_bottom_right = (left + text_w + pad_x, top + text_h + pad_y)
    cv2.rectangle(img, bg_top_left, bg_bottom_right, text_color_bg, -1)

    # putText anchors at the bottom-left of the text, so shift down by its height.
    origin = (left, int(top + text_h + font_scale - 1))
    cv2.putText(img, text, origin, font, font_scale, text_color, font_thickness, cv2.LINE_AA)
    return img
def annotate_fps(orig_image, fps_text):
    """Stamp an ``FPS: <value>`` label near the top-left corner of *orig_image*.

    *fps_text* is a number (formatted to one decimal place), despite its name.
    The label is drawn in place via ``draw_text`` at (20, 20), and the same
    image object is returned.
    """
    label = f"FPS: {fps_text:0.1f}"
    style = dict(
        font_scale=1.0,
        font_thickness=2,
        text_color=(204, 85, 17),
        text_color_bg=(255, 255, 255),
    )
    draw_text(orig_image, label, pos=(20, 20), **style)
    return orig_image