|
from PIL import Image, ImageDraw, ImageFont |
|
from dotenv import load_dotenv |
|
import matplotlib.pyplot as plt |
|
from moviepy.editor import * |
|
from io import BytesIO |
|
from glob import glob |
|
import gradio as gr |
|
import numpy as np |
|
import random |
|
import requests |
|
import base64 |
|
import boto3 |
|
import uuid |
|
import os |
|
import io |
|
|
|
# Load variables from a local .env file (if one exists) into os.environ
# before any of them are read below.
load_dotenv()

# Credentials come from the environment; os.getenv yields None when a
# variable is unset.
# NOTE(review): with None values boto3 presumably falls back to its default
# credential chain -- confirm for the deployment target.
AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
AWS_SECRET_ACCESS_KEY = os.getenv('AWS_SECRET_ACCESS_KEY')

# Shared S3 client used by the upload helpers.
s3 = boto3.client(
    's3',
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
)

# Re-seed the global RNG (no argument -> system entropy/time) so the
# example images shuffled for the UI differ between runs.
random.seed()
|
|
|
|
|
def upload2aws(img_array, bucket='predict-packages', prefix='images_webapp_counters'):
    """Archive an image to S3 as a JPEG stored under a random UUID key.

    Args:
        img_array: Array accepted by ``PIL.Image.fromarray``.
        bucket: Destination S3 bucket (defaults to the app's bucket).
        prefix: Key prefix; the object is written to ``{prefix}/{uuid}.jpg``.

    Returns:
        None. The upload goes through the module-level ``s3`` client and any
        boto3 error propagates to the caller.
    """
    image = Image.fromarray(img_array)
    # Pillow's JPEG writer rejects modes that carry alpha or a palette
    # (e.g. 'RGBA', 'LA', 'P') with an OSError; flatten anything outside the
    # JPEG-writable modes to RGB rather than failing the whole request.
    if image.mode not in ('L', 'RGB', 'CMYK'):
        image = image.convert('RGB')

    buffer = io.BytesIO()
    image.save(buffer, format='JPEG')
    buffer.seek(0)  # rewind so put_object reads the payload from the start

    key = f'{prefix}/{uuid.uuid4()}.jpg'
    s3.put_object(Bucket=bucket, Key=key, Body=buffer)
    return None
|
|
|
def vidupload2aws(vid_path, bucket='predict-packages', prefix='images_webapp_counters/videos'):
    """Archive a local video file to S3 under a random UUID key.

    The original file extension is kept, so ``clip.mp4`` is stored as
    ``{prefix}/{uuid}.mp4``.

    Args:
        vid_path: Path of the local video file to upload.
        bucket: Destination S3 bucket (defaults to the app's bucket).
        prefix: Key prefix for the stored object.

    Returns:
        None. Uses the module-level ``s3`` client; boto3 errors propagate.
    """
    _, ext = os.path.splitext(os.path.basename(vid_path))
    key = f'{prefix}/{uuid.uuid4()}{ext}'
    s3.upload_file(vid_path, bucket, key)
    return None
|
|
|
def send2api(input_img, api_url, timeout=None):
    """Encode an image as JPEG and POST it to a prediction endpoint.

    Args:
        input_img: Image array accepted by ``matplotlib.pyplot.imsave``.
        api_url: URL of the service's ``/predict`` endpoint.
        timeout: Optional ``requests`` timeout in seconds. ``None`` (the
            default) waits indefinitely, matching the previous behaviour.

    Returns:
        The decoded JSON body on success. Returns ``None`` in three cases:
        the server answers 204 No Content, it returns an HTTP error status,
        or the body is not valid JSON. The error is printed in the latter
        two cases.

    Raises:
        requests.RequestException: Connection-level failures raised by
            ``requests.post`` itself are not caught.
    """
    buf = io.BytesIO()
    plt.imsave(buf, input_img, format='jpg')
    files = {'image': buf.getvalue()}

    res = requests.post(api_url, files=files, timeout=timeout)

    # Bind the result up front: previously `response` was only assigned on
    # the success path. Any error or a 204 reply therefore ended in an
    # UnboundLocalError that hid the real cause.
    response = None
    try:
        res.raise_for_status()
        if res.status_code != 204:
            response = res.json()
    except (requests.RequestException, ValueError) as e:
        # HTTPError comes from raise_for_status; an invalid JSON body raises
        # a ValueError subclass.
        print(str(e))
    return response
|
|
|
def displaytext_detclasim(c_cnames, c_scinames, coverage):
    """Format the countingid results as a multi-line text report.

    The report has four parts, in this order:

    1. The free space, computed as ``100 - int(coverage)``.
    2. The counts per scientific name.
    3. The counts per common name.
    4. The grand total.

    Both count lists are sorted by count, highest first; ties keep their
    input order. The total is the sum of the scientific-name counts.

    Args:
        c_cnames: Mapping of common name -> count.
        c_scinames: Mapping of scientific name -> count.
        coverage: Covered percentage, either a string such as ``'37.5'`` or a
            number. Only the integer part is used; the fraction is truncated.

    Returns:
        The report string, ending with a newline.
    """
    by_sci = sorted(c_scinames.items(), key=lambda kv: kv[1], reverse=True)
    by_common = sorted(c_cnames.items(), key=lambda kv: kv[1], reverse=True)
    total = sum(count for _, count in by_sci)

    # Accept a numeric coverage (e.g. a float decoded from JSON) as well as
    # the string form; int() truncates exactly like dropping the '.xxx' part.
    if isinstance(coverage, (int, float)):
        covered = int(coverage)
    else:
        covered = int(coverage.split('.')[0])
    free = 100 - covered

    lines = [f'free space = {free}%', '', 'Countings by scientific name:']
    lines += [f'{key} = {value}' for key, value in by_sci]
    lines += ['', '', 'Countings by common name:']
    lines += [f'{key} = {value}' for key, value in by_common]
    lines += ['', f'total = {total}']
    return '\n'.join(lines) + '\n'
|
|
|
def displaytext_yolocounter(countings, coverage):
    """Format the yolocounter results as a multi-line text report.

    The report has three parts, in this order:

    1. The free space, computed as ``100 - int(coverage)``.
    2. One ``name = count`` line per class, highest count first (ties keep
       their input order).
    3. The total of all counts.

    Args:
        countings: Mapping of class name -> count.
        coverage: Covered percentage, either a string such as ``'37.5'`` or a
            number. Only the integer part is used; the fraction is truncated.

    Returns:
        The report string, ending with a newline.
    """
    ranked = sorted(countings.items(), key=lambda kv: kv[1], reverse=True)
    total = sum(count for _, count in ranked)

    # Accept a numeric coverage (e.g. a float decoded from JSON) as well as
    # the string form; int() truncates exactly like dropping the '.xxx' part.
    if isinstance(coverage, (int, float)):
        covered = int(coverage)
    else:
        covered = int(coverage.split('.')[0])
    free = 100 - covered

    lines = [f'free space = {free}%', '']
    lines += [f'{key} = {value}' for key, value in ranked]
    lines += ['', f'total = {total}']
    return '\n'.join(lines) + '\n'
|
|
|
def display_detectionsandcountings_directcounter(img_array, countings, prob_th=0, cth=0):
    """Annotate an image with per-class counts and build a text summary.

    Classes are ranked by count, highest first. Only classes whose count is
    strictly greater than ``cth`` are drawn on the image and added to the
    total. The returned text lists every class regardless of ``cth``.

    Args:
        img_array: Array passed to ``PIL.Image.fromarray``.
        countings: Mapping of class name -> count.
        prob_th: Accepted for interface compatibility; not used.
        cth: Exclusive count threshold for drawing and totalling.

    Returns:
        ``(image, text)``: the annotated PIL image and the summary string.
    """
    canvas = Image.fromarray(img_array)
    pen = ImageDraw.Draw(canvas)

    # PIL's size is (width, height). The layout is scaled by the width
    # relative to a 4000-pixel reference.
    scale = canvas.size[0] / 4000
    left = int(50 * scale)
    line_step = int(100 * scale)

    ranked = sorted(countings.items(), key=lambda item: item[1], reverse=True)
    shown = [(name, n) for name, n in ranked if n > cth]

    top = int(20 * scale)
    for name, n in shown:
        pen.text((left, top), "# {} = {}".format(name, n), fill='red')
        top += line_step
    total = sum(n for _, n in shown)

    # Leave one extra line of gap before the total.
    top += line_step
    pen.text((left, top), "# {} = {}".format('total', total), fill='red')

    # NOTE(review): the text lists all classes, but `total` only counts those
    # above `cth`. That is equivalent for cth=0 with non-negative counts.
    text = ''.join(f'{name} = {n}\n' for name, n in ranked)
    text += '\n' + f'total = {total}' + '\n'
    return canvas, text
|
|
|
def testing_countingid(input_img,
                       api_url='http://countingid-test.us-east-1.elasticbeanstalk.com/predict'):
    """Gradio handler: species & common-name counting for one image.

    The image is first archived to S3 via ``upload2aws``, then sent to the
    countingid service. The response's base64 ``img_out`` is decoded into
    a PIL image, and its counts and coverage are formatted by
    ``displaytext_detclasim``.

    Args:
        input_img: Image array from ``gr.Image``; ``None`` if nothing was
            provided.
        api_url: Prediction endpoint (defaults to the test deployment).

    Returns:
        ``(annotated_image, text_report)``.

    Raises:
        gr.Error: If no image was given or the service produced no result.
    """
    if input_img is None:
        # Without this guard the request crashed inside upload2aws.
        raise gr.Error('Please provide an image before submitting.')

    upload2aws(input_img)

    response = send2api(input_img, api_url)
    if response is None:
        raise gr.Error('The countingid service returned no usable result.')

    img = Image.open(BytesIO(base64.b64decode(response['img_out'])))
    text = displaytext_detclasim(response['countings_cnames'],
                                 response['countings_scinames'],
                                 response['coverage'])
    return img, text
|
|
|
def testing_yolocounter(input_img,
                        api_url='http://yolocounter-test.us-east-1.elasticbeanstalk.com/predict'):
    """Gradio handler: simplified scientific-name counting for one image.

    The image is sent to the yolocounter service. The response's base64
    ``img_out`` is decoded into a PIL image, and its counts and coverage
    are formatted by ``displaytext_yolocounter``. This function is also
    used as the per-frame detector for video processing.

    Args:
        input_img: Image array (from ``gr.Image`` or a video frame);
            ``None`` if nothing was provided.
        api_url: Prediction endpoint (defaults to the test deployment).

    Returns:
        ``(annotated_image, text_report)``.

    Raises:
        gr.Error: If no image was given or the service produced no result.
    """
    if input_img is None:
        raise gr.Error('Please provide an image before submitting.')

    response = send2api(input_img, api_url)
    if response is None:
        raise gr.Error('The yolocounter service returned no usable result.')

    img = Image.open(BytesIO(base64.b64decode(response['img_out'])))
    text = displaytext_yolocounter(response['countings_scinames'],
                                   response['coverage'])
    return img, text
|
|
|
def testing_directcounter(input_img,
                          api_url='http://directcounter-test.us-east-1.elasticbeanstalk.com/predict'):
    """Gradio handler: direct insect counting for one image.

    The image is sent to the directcounter service. The returned
    ``countings_scinames`` are drawn onto the original input image by
    ``display_detectionsandcountings_directcounter``.

    Args:
        input_img: Image array from ``gr.Image``; ``None`` if nothing was
            provided.
        api_url: Prediction endpoint (defaults to the test deployment).

    Returns:
        ``(annotated_image, text_report)``.

    Raises:
        gr.Error: If no image was given or the service produced no result.
    """
    if input_img is None:
        raise gr.Error('Please provide an image before submitting.')

    response = send2api(input_img, api_url)
    if response is None:
        raise gr.Error('The directcounter service returned no usable result.')

    return display_detectionsandcountings_directcounter(
        input_img, response['countings_scinames'], prob_th=0, cth=0)
|
|
|
|
|
def extractframes(vid_path, max_seconds=10, frames_per_second=2):
    """Sample frames from the beginning of a video.

    Reads at most the first ``max_seconds`` seconds of ``vid_path``. It
    keeps every ``round(fps / frames_per_second)``-th decoded frame, which
    gives about ``frames_per_second`` frames per second of footage.

    Args:
        vid_path: Path of the video file.
        max_seconds: Length of the window read from the start of the video.
        frames_per_second: Target sampling rate (must be > 0).

    Returns:
        List of frames as yielded by moviepy's ``iter_frames``
        (NOTE(review): presumably HxWx3 uint8 arrays -- confirm).
    """
    clip = VideoFileClip(vid_path)
    try:
        # Clamp the window to the clip. moviepy 1.x does not reject an end
        # time past the duration; it keeps returning the last frame, which
        # padded short videos with duplicates.
        if clip.duration is None:
            end = max_seconds
        else:
            end = min(max_seconds, clip.duration)
        window = clip.subclip(0, end)

        # Use an integer stride of at least 1. The old `int(round(fps)) / 2`
        # was a float. For odd frame rates (25 fps -> 12.5) `i % stride == 0`
        # held only every 25th frame, halving the sample rate. A frame rate
        # below 0.5 raised ZeroDivisionError.
        stride = max(1, int(round(window.fps / frames_per_second)))
        return [frame for i, frame in enumerate(window.iter_frames())
                if i % stride == 0]
    finally:
        # Release the ffmpeg reader held by the VideoFileClip; the frames
        # collected above are independent arrays.
        clip.close()
|
|
|
def processvideo(vid_path, detector):
    """Run a detector on sampled frames and assemble the results into a video.

    Args:
        vid_path: Path of the input video; frames come from
            ``extractframes``.
        detector: Callable taking one frame and returning ``(image, text)``,
            e.g. ``testing_yolocounter``. Only the image is used.

    Returns:
        Path of the written MP4, played back at 2 frames per second.

    Raises:
        gr.Error: If no frames could be read from the video.
    """
    frames = extractframes(vid_path)
    if not frames:
        # ImageSequenceClip fails with an opaque error on an empty list.
        raise gr.Error('No frames could be read from the uploaded video.')

    annotated = []
    for frame in frames:
        img, _ = detector(frame)
        annotated.append(np.asarray(img))

    clip = ImageSequenceClip(annotated, fps=2)

    # A fixed "out.mp4" let overlapping requests overwrite each other's
    # result. Use a per-call name instead.
    # NOTE: these files are not cleaned up here.
    outvid_path = f"out_{uuid.uuid4().hex}.mp4"
    clip.write_videofile(outvid_path)
    return outvid_path
|
|
|
def video_identity(video):
    """Gradio handler for the video tab.

    Archives the uploaded file to S3 (``vidupload2aws``). It then annotates
    sampled frames with ``testing_yolocounter`` via ``processvideo``.

    Args:
        video: Path of the uploaded video from ``gr.Video``; ``None`` when
            the user submits without a file.

    Returns:
        Path of the annotated output video.

    Raises:
        gr.Error: If no video was provided.
    """
    if video is None:
        # Without this guard the failure surfaced as a TypeError from
        # os.path.basename(None) inside vidupload2aws.
        raise gr.Error('Please upload a video before submitting.')

    vidupload2aws(video)
    return processvideo(video, detector=testing_yolocounter)
|
|
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("Submit an image with insects in a trap")

    with gr.Tab("Species & Common Name Count"):
        with gr.Row():
            input1 = gr.Image()
            output1 = [gr.Image(height=500, width=500), gr.Textbox(lines=20)]
        button1 = gr.Button("Submit")
        button1.click(testing_countingid, input1, output1)

    # This tab has no input or button of its own. It reuses the first tab's
    # image and Submit button, so a single click runs both services.
    with gr.Tab("Simplified Scientific Name Count"):
        with gr.Row():
            output2 = [gr.Image(height=500, width=500), gr.Textbox(lines=20)]
        button1.click(testing_yolocounter, input1, output2)

    with gr.Tab("Processing on a video (under development)"):
        with gr.Row():
            input3 = gr.Video()
            output3 = gr.Video()
        button3 = gr.Button("Submit")
        button3.click(video_identity, input3, output3)

    # Offer up to four randomly chosen example images. Skip the widget when
    # the folder is empty rather than building an examples set with no entries.
    examples_list = glob("img_examples/*.jpg")
    random.shuffle(examples_list)
    if examples_list:
        examples = gr.Examples(examples=examples_list[:4], inputs=[input1])

    # Disabled "Direct insect counter" tab (uses testing_directcounter);
    # kept for reference:
    # with gr.Tab("Direct insect counter"):
    #     with gr.Row():
    #         output4 = [gr.Image(height=500, width=500), gr.Textbox(lines=20)]
    #     button1.click(testing_directcounter, input1, output4)

if __name__ == "__main__":
    demo.launch()