ezamorag's picture
Update app.py
d5025eb
raw
history blame contribute delete
No virus
7.13 kB
from PIL import Image, ImageDraw, ImageFont
from dotenv import load_dotenv
import matplotlib.pyplot as plt
from moviepy.editor import *
from io import BytesIO
from glob import glob
import gradio as gr
import numpy as np
import random
import requests
import base64
import boto3
import uuid
import os
import io
# Reseed the module-level RNG (used further down to shuffle the example
# images) and load variables from a local .env file, if one exists.
random.seed()
load_dotenv()

# AWS credentials are read from the environment (possibly populated by the
# .env file above); os.getenv yields None for unset variables.
AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
AWS_SECRET_ACCESS_KEY = os.getenv('AWS_SECRET_ACCESS_KEY')

# Shared S3 client used by upload2aws and vidupload2aws.
s3 = boto3.client(
    's3',
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
)
def upload2aws(img_array):
    """Archive an image array to S3 as a JPEG.

    The array is converted with ``PIL.Image.fromarray``, JPEG-encoded in
    memory and stored in the ``predict-packages`` bucket under
    ``images_webapp_counters/<random uuid4>.jpg``.

    Args:
        img_array: image array accepted by ``PIL.Image.fromarray``.

    Returns:
        None.
    """
    pil_image = Image.fromarray(img_array)
    with io.BytesIO() as jpeg_bytes:
        pil_image.save(jpeg_bytes, format='JPEG')
        # Rewind so put_object uploads from the first byte.
        jpeg_bytes.seek(0)
        object_id = str(uuid.uuid4())
        s3.put_object(
            Bucket='predict-packages',
            Key=f'images_webapp_counters/{object_id}.jpg',
            Body=jpeg_bytes,
        )
    return None
def vidupload2aws(vid_path):
    """Archive a video file to S3 under a random name.

    The file keeps its original extension and is uploaded to the
    ``predict-packages`` bucket as
    ``images_webapp_counters/videos/<random uuid4><ext>``.

    Args:
        vid_path: local path of the video file.

    Returns:
        None.
    """
    extension = os.path.splitext(os.path.basename(vid_path))[1]
    object_id = str(uuid.uuid4())
    object_key = f'images_webapp_counters/videos/{object_id}{extension}'
    s3.upload_file(vid_path, 'predict-packages', object_key)
    return None
def send2api(input_img, api_url):
    """POST an image to a prediction endpoint and return its JSON reply.

    The image is JPEG-encoded with ``matplotlib.pyplot.imsave`` and sent as
    the multipart file field ``image``.

    Args:
        input_img: image array accepted by ``plt.imsave``.
        api_url: URL of the prediction endpoint.

    Returns:
        The decoded JSON body, or ``None`` when the server answers with an
        HTTP error status, returns invalid JSON, or replies 204 No Content.
        Those failures are printed rather than raised, so callers must
        check for ``None``.

    Raises:
        requests.RequestException: connection-level failures raised by
            ``requests.post`` itself are not caught here.
    """
    with io.BytesIO() as buf:
        plt.imsave(buf, input_img, format='jpg')
        files = {'image': buf.getvalue()}
    res = requests.post(api_url, files=files)
    # Initialise first: without this, every handled failure (and every 204)
    # fell through to `return response` and raised UnboundLocalError.
    response = None
    try:
        res.raise_for_status()
        if res.status_code != 204:
            response = res.json()
    except (requests.RequestException, ValueError) as e:
        # HTTPError from raise_for_status; JSON decode errors subclass ValueError.
        print(str(e))
    return response
def displaytext_detclasim(c_cnames, c_scinames, coverage):
    """Format the species/common-name counts as a plain-text report.

    Args:
        c_cnames: mapping of common name -> count.
        c_scinames: mapping of scientific name -> count.
        coverage: string whose integer part (text before the first '.') is
            the covered percentage; free space is reported as 100 minus it.

    Returns:
        A report with the free space, the scientific-name counts, the
        common-name counts (each section sorted by descending count, ties
        keeping input order) and the total. The total is the sum of the
        scientific-name counts.
    """
    by_sci = sorted(c_scinames.items(), key=lambda pair: pair[1], reverse=True)
    grand_total = sum(count for _, count in by_sci)
    free_pct = 100 - int(coverage.split('.')[0])

    parts = [f'free space = {free_pct}%' + '\n\n', 'Countings by scientific name:\n']
    parts.extend(f'{name} = {count}' + '\n' for name, count in by_sci)
    parts.append('\n\n')
    parts.append('Countings by common name:\n')
    by_common = sorted(c_cnames.items(), key=lambda pair: pair[1], reverse=True)
    parts.extend(f'{name} = {count}' + '\n' for name, count in by_common)
    parts.append('\n')
    parts.append(f'total = {grand_total}' + '\n')
    return ''.join(parts)
def displaytext_yolocounter(countings, coverage):
    """Format per-class counts from the YOLO counter as plain text.

    Args:
        countings: mapping of class name -> count.
        coverage: string whose integer part (text before the first '.') is
            the covered percentage; free space is reported as 100 minus it.

    Returns:
        The free space, then one ``name = count`` line per class (descending
        count, ties keep input order), a blank line, and the total.
    """
    ranked = sorted(countings.items(), key=lambda pair: pair[1], reverse=True)
    grand_total = sum(n for _, n in ranked)
    occupied = int(coverage.split('.')[0])
    body = ''.join(f'{label} = {n}' + '\n' for label, n in ranked)
    return f'free space = {100 - occupied}%' + '\n\n' + body + '\n' + f'total = {grand_total}' + '\n'
def display_detectionsandcountings_directcounter(img_array, countings, prob_th=0, cth = 0):
    """Draw per-class counts onto an image and build a matching text summary.

    Args:
        img_array: image array accepted by ``PIL.Image.fromarray``.
        countings: mapping of class name -> count.
        prob_th: unused; kept so existing callers keep working.
        cth: only classes whose count is strictly greater than ``cth`` are
            drawn, listed in the text, and added to the total.

    Returns:
        ``(image, text)``: the PIL image with red ``# name = count`` lines
        and a ``# total = N`` line drawn on it, plus the same counts as
        plain text.
    """
    img = Image.fromarray(img_array)
    draw = ImageDraw.Draw(img)
    # PIL's Image.size is (width, height); text offsets scale with the image
    # width relative to a 4000-px reference.
    width, _height = img.size
    ratio = width / 4000

    ranked = sorted(countings.items(), key=lambda pair: pair[1], reverse=True)
    # Filter once so the drawn lines, the text lines and the total all
    # describe the same set of classes.
    kept = [(name, count) for name, count in ranked if count > cth]

    x = int(50 * ratio)
    y = int(20 * ratio)
    line_step = int(100 * ratio)
    total = 0
    for name, count in kept:
        draw.text((x, y), "# {} = {}".format(name, count), fill='red')
        y += line_step
        total += count
    # Extra gap before the total line.
    y += line_step
    draw.text((x, y), "# {} = {}".format('total', total), fill='red')

    text = ''.join(f'{name} = {count}' + '\n' for name, count in kept)
    text += '\n'
    text += f'total = {total}' + '\n'
    return img, text
def testing_countingid(input_img):
    """Gradio handler for the "Species & Common Name Count" tab.

    Archives the image to S3 and sends it to the countingid test endpoint.
    Returns the image sent back (base64 in ``img_out``) plus a text summary
    built from ``countings_cnames``, ``countings_scinames`` and ``coverage``.

    Raises:
        gr.Error: if no image was provided, or the service returned no result.
    """
    # Gradio passes None when Submit is clicked without an image.
    if input_img is None:
        raise gr.Error('Please upload an image first.')
    upload2aws(input_img)
    api_url = 'http://countingid-test.us-east-1.elasticbeanstalk.com/predict'
    response = send2api(input_img, api_url)
    if not response:
        raise gr.Error('The counting service did not return a result. Please try again.')
    img = Image.open(BytesIO(base64.b64decode(response['img_out'])))
    text = displaytext_detclasim(
        response['countings_cnames'],
        response['countings_scinames'],
        response['coverage'],
    )
    return img, text
def testing_yolocounter(input_img):
    """Gradio handler for the "Simplified Scientific Name Count" tab.

    Sends the image to the yolocounter test endpoint. Returns the image sent
    back (base64 in ``img_out``) plus a text summary of
    ``countings_scinames`` and ``coverage``. Also used as the per-frame
    detector for video processing.

    Raises:
        gr.Error: if no image was provided, or the service returned no result.
    """
    # Gradio passes None when Submit is clicked without an image.
    if input_img is None:
        raise gr.Error('Please upload an image first.')
    api_url = 'http://yolocounter-test.us-east-1.elasticbeanstalk.com/predict'
    response = send2api(input_img, api_url)
    if not response:
        raise gr.Error('The counting service did not return a result. Please try again.')
    img = Image.open(BytesIO(base64.b64decode(response['img_out'])))
    text = displaytext_yolocounter(response['countings_scinames'], response['coverage'])
    return img, text
def testing_directcounter(input_img):
    """Handler for the direct-counter endpoint.

    Sends the image to the directcounter test endpoint and draws the
    returned ``countings_scinames`` onto the input image locally.

    Raises:
        gr.Error: if no image was provided, or the service returned no result.
    """
    # Gradio passes None when Submit is clicked without an image.
    if input_img is None:
        raise gr.Error('Please upload an image first.')
    api_url = 'http://directcounter-test.us-east-1.elasticbeanstalk.com/predict'
    response = send2api(input_img, api_url)
    if not response:
        raise gr.Error('The counting service did not return a result. Please try again.')
    countings = response['countings_scinames']
    img, text = display_detectionsandcountings_directcounter(input_img, countings, prob_th=0, cth=0)
    return img, text
#------------------------------------------------
def extractframes(vid_path, max_seconds=10, target_fps=2):
    """Sample frames from the beginning of a video.

    Args:
        vid_path: path to a video file readable by moviepy's VideoFileClip.
        max_seconds: only the first ``max_seconds`` seconds are read (the
            whole clip when it is shorter).
        target_fps: approximate number of frames kept per second of video.

    Returns:
        List of frame arrays as yielded by ``iter_frames``.
    """
    clip = VideoFileClip(vid_path)
    try:
        # subclip raises when the end time exceeds the clip duration, so
        # clamp it for videos shorter than max_seconds.
        end = min(max_seconds, clip.duration)
        sub = clip.subclip(0, end)
        # Integer stride so `i % stride == 0` fires regularly. A float stride
        # such as 12.5 (25 fps) only matches every 25th frame. The stride is
        # at least 1 so very low frame rates cannot divide by zero.
        stride = max(1, int(round(clip.fps / target_fps)))
        frames = [frame for i, frame in enumerate(sub.iter_frames()) if i % stride == 0]
    finally:
        # Release the underlying ffmpeg reader.
        clip.close()
    return frames
def processvideo(vid_path, detector):
    """Run ``detector`` on frames sampled from a video and re-encode the results.

    Args:
        vid_path: path of the input video.
        detector: callable taking a frame array and returning
            ``(image, text)``; only the image is kept.

    Returns:
        Path of the written MP4, encoded at 2 fps (intended to match the
        sampling rate of ``extractframes``).

    Raises:
        ValueError: if no frames could be extracted from the video.
    """
    frames = extractframes(vid_path)
    if not frames:
        raise ValueError(f'No frames could be extracted from {vid_path!r}')
    img_list = []
    for frame in frames:
        img, _ = detector(frame)
        img_list.append(np.asarray(img))
    clip = ImageSequenceClip(img_list, fps=2)
    # Unique name per call: a fixed "out.mp4" is shared by all concurrent
    # users of the app and would be overwritten mid-request.
    outvid_path = f'out_{uuid.uuid4().hex}.mp4'
    try:
        clip.write_videofile(outvid_path)
    finally:
        clip.close()
    return outvid_path
def video_identity(video):
    """Gradio handler for the video tab.

    Archives the uploaded video to S3, then returns the path of a new video
    whose frames were annotated by ``testing_yolocounter``.

    Raises:
        gr.Error: if no video was provided.
    """
    # Gradio passes None when Submit is clicked without a video.
    if video is None:
        raise gr.Error('Please upload a video first.')
    vidupload2aws(video)
    return processvideo(video, detector=testing_yolocounter)
#-------------------------------------------------
# Gradio UI. The Submit button on the first tab (button1) drives both image
# tabs: one click sends input1 to the species/common-name counter and to the
# simplified scientific-name counter.
with gr.Blocks() as demo:
    gr.Markdown("Submit an image with insects in a trap")
    with gr.Tab("Species & Common Name Count"):
        with gr.Row():
            input1 = gr.Image()
            output1 = [gr.Image(height=500, width=500), gr.Textbox(lines=20)]
        button1 = gr.Button("Submit")
        button1.click(testing_countingid, input1, output1)
    with gr.Tab("Simplified Scientific Name Count"):
        with gr.Row():
            output2 = [gr.Image(height=500, width=500), gr.Textbox(lines=20)]
        # No input or button of its own: reuses input1/button1 from the first tab.
        button1.click(testing_yolocounter, input1, output2)
    with gr.Tab("Processing on a video (under development)"):
        with gr.Row():
            input3 = gr.Video()
            output3 = gr.Video()
        button3 = gr.Button("Submit")
        button3.click(video_identity, input3, output3)
    # Offer up to four randomly chosen sample images for the image input.
    examples_list = glob("img_examples/*.jpg")
    random.shuffle(examples_list)
    examples = gr.Examples(examples=examples_list[:4], inputs=[input1])
    # A "Direct insect counter" tab is currently disabled. When enabled, it
    # wires testing_directcounter to input1/button1 like the tab above, with
    # its own [gr.Image(height=500, width=500), gr.Textbox(lines=20)] outputs.

# Launch only when run as a script, so importing the module (e.g. Gradio's
# reload mode) does not start a second server.
if __name__ == "__main__":
    demo.launch()