Spaces:
Sleeping
Sleeping
| from PIL import Image, ImageDraw, ImageFont | |
| from dotenv import load_dotenv | |
| import matplotlib.pyplot as plt | |
| from moviepy.editor import * | |
| from io import BytesIO | |
| from glob import glob | |
| import gradio as gr | |
| import numpy as np | |
| import random | |
| import requests | |
| import base64 | |
| import boto3 | |
| import uuid | |
| import os | |
| import io | |
# Re-seed Python's RNG from system entropy (used when shuffling example images).
random.seed()

# Populate os.environ from a local .env file, if one exists.
load_dotenv()

# AWS credentials; os.environ.get yields None when a variable is unset.
AWS_ACCESS_KEY_ID = os.environ.get('AWS_ACCESS_KEY_ID')
AWS_SECRET_ACCESS_KEY = os.environ.get('AWS_SECRET_ACCESS_KEY')

# Shared S3 client used by the upload helpers in this module.
s3 = boto3.client(
    's3',
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
)
def upload2aws(img_array):
    """Archive an image as a JPEG object in the ``predict-packages`` S3 bucket.

    The object key is ``images_webapp_counters/<uuid4>.jpg``, so every call
    creates a new object.

    Args:
        img_array: image as an array accepted by ``PIL.Image.fromarray``.

    Returns:
        None.
    """
    image = Image.fromarray(img_array)
    # JPEG has no alpha or palette support: PIL raises OSError when asked to
    # save e.g. an RGBA image (4-channel array) as JPEG, so flatten first.
    if image.mode in ('RGBA', 'LA', 'P', 'PA'):
        image = image.convert('RGB')
    with io.BytesIO() as buffer:
        image.save(buffer, format='JPEG')
        buffer.seek(0)
        unique_name = str(uuid.uuid4())
        s3.put_object(Bucket='predict-packages', Key=f'images_webapp_counters/{unique_name}.jpg', Body=buffer)
    return None
def vidupload2aws(vid_path):
    """Upload a video file to the ``predict-packages`` S3 bucket under a random name.

    The object key is ``images_webapp_counters/videos/<uuid4><ext>``; the
    original file extension is preserved.

    Args:
        vid_path: local path of the video file to upload.

    Returns:
        None.
    """
    extension = os.path.splitext(os.path.basename(vid_path))[1]
    object_key = f'images_webapp_counters/videos/{uuid.uuid4()}{extension}'
    s3.upload_file(vid_path, 'predict-packages', object_key)
    return None
def send2api(input_img, api_url):
    """POST an image to a prediction endpoint and return its decoded JSON reply.

    The image is JPEG-encoded in memory with ``plt.imsave`` and sent as the
    multipart form field ``image``.

    Args:
        input_img: image array accepted by ``matplotlib.pyplot.imsave``.
        api_url: URL of the prediction endpoint.

    Returns:
        The parsed JSON body, or ``None`` when the server answered with an
        HTTP error status, with 204 No Content, or with a body that is not
        valid JSON (the error is printed).
    """
    with io.BytesIO() as buf:
        plt.imsave(buf, input_img, format='jpg')
        files = {'image': buf.getvalue()}
    res = requests.post(api_url, files=files)
    # Default result, so the failure / 204 paths return None instead of
    # hitting an unassigned local.
    response = None
    try:
        res.raise_for_status()
        if res.status_code != 204:
            response = res.json()
    except (requests.RequestException, ValueError) as e:
        # HTTPError from raise_for_status, or a JSON decode error from .json().
        print(str(e))
    return response
def displaytext_detclasim(c_cnames, c_scinames, coverage):
    """Format the CountingID results as a plain-text report.

    Args:
        c_cnames: mapping of common name -> count.
        c_scinames: mapping of scientific name -> count.
        coverage: covered-area percentage as a string; only the part before
            the first '.' is used (e.g. '37.5' -> 37).

    Returns:
        A multi-line string with the free space (100 - coverage), the counts
        by scientific name, the counts by common name (each sorted by
        descending count) and a total summed over the scientific-name counts.
    """
    def _by_count_desc(mapping):
        # sorted() is stable, so names with equal counts keep their dict order.
        return sorted(mapping.items(), key=lambda pair: pair[1], reverse=True)

    sci_ranked = _by_count_desc(c_scinames)
    total = sum(count for _, count in sci_ranked)
    free = 100 - int(coverage.split('.')[0])
    common_ranked = _by_count_desc(c_cnames)

    pieces = [f'free space = {free}%' + '\n\n', 'Countings by scientific name:\n']
    pieces += [f'{name} = {count}' + '\n' for name, count in sci_ranked]
    pieces += ['\n\n', 'Countings by common name:\n']
    pieces += [f'{name} = {count}' + '\n' for name, count in common_ranked]
    pieces += ['\n', f'total = {total}' + '\n']
    return ''.join(pieces)
def displaytext_yolocounter(countings, coverage):
    """Format the YOLO counter results as a plain-text report.

    Args:
        countings: mapping of class name -> count.
        coverage: covered-area percentage as a string; only the part before
            the first '.' is used.

    Returns:
        A multi-line string with the free space (100 - coverage), one
        ``name = count`` line per class sorted by descending count, and the
        total count.
    """
    ranked = sorted(countings.items(), key=lambda pair: pair[1], reverse=True)
    total = sum(count for _, count in ranked)
    free = 100 - int(coverage.split('.')[0])

    header = f'free space = {free}%' + '\n\n'
    rows = ''.join(f'{label} = {count}' + '\n' for label, count in ranked)
    footer = '\n' + f'total = {total}' + '\n'
    return header + rows + footer
def display_detectionsandcountings_directcounter(img_array, countings, prob_th=0, cth = 0):
    """Draw per-class counts onto an image and build a matching text report.

    Args:
        img_array: image as an array accepted by ``PIL.Image.fromarray``.
        countings: mapping of class name -> count.
        prob_th: accepted for call compatibility; not used.
        cth: only classes whose count is strictly greater than ``cth`` are
            drawn on the image and added to the total.

    Returns:
        ``(img, text)``: the annotated PIL image and a text report. The report
        lists every class (sorted by descending count), while its total only
        includes the classes that were drawn.
    """
    img = Image.fromarray(img_array)
    painter = ImageDraw.Draw(img)
    # PIL's size is (width, height); label placement scales with the image
    # width relative to a 4000-px reference.
    scale = img.size[0] / 4000
    ranked = sorted(countings.items(), key=lambda pair: pair[1], reverse=True)

    x = int(50 * scale)
    y = int(20 * scale)
    total = 0
    for label, count in ranked:
        if not count > cth:
            continue
        painter.text((x, y), "# {} = {}".format(label, count), fill='red')
        y += int(100 * scale)
        total += count
    y += int(100 * scale)
    painter.text((x, y), "# {} = {}".format('total', total), fill='red')

    text = ''.join(f'{label} = {count}' + '\n' for label, count in ranked)
    text += '\n' + f'total = {total}' + '\n'
    return img, text
def testing_countingid(input_img):
    """Run the CountingID model on one image and format its results for the UI.

    The input image is first archived to S3 with ``upload2aws``.

    Args:
        input_img: image array from the ``gr.Image`` input (``None`` if empty).

    Returns:
        ``(image, text)``: the annotated image decoded from the API's base64
        ``img_out`` field, and the report built by ``displaytext_detclasim``.

    Raises:
        gr.Error: if no image was provided or the service returned no result.
    """
    if input_img is None:
        raise gr.Error('Please provide an image first.')
    upload2aws(input_img)
    api_url = 'http://countingid-test.us-east-1.elasticbeanstalk.com/predict'
    response = send2api(input_img, api_url)
    if response is None:
        raise gr.Error('The prediction service did not return a result.')
    img = Image.open(BytesIO(base64.b64decode(response['img_out'])))
    text = displaytext_detclasim(
        response['countings_cnames'],
        response['countings_scinames'],
        response['coverage'],
    )
    return img, text
def testing_yolocounter(input_img):
    """Run the YOLO counter model on one image and format its results for the UI.

    Args:
        input_img: image array (from ``gr.Image`` or a video frame); ``None``
            if the image input is empty.

    Returns:
        ``(image, text)``: the annotated image decoded from the API's base64
        ``img_out`` field, and the report built by ``displaytext_yolocounter``.

    Raises:
        gr.Error: if no image was provided or the service returned no result.
    """
    if input_img is None:
        raise gr.Error('Please provide an image first.')
    api_url = 'http://yolocounter-test.us-east-1.elasticbeanstalk.com/predict'
    response = send2api(input_img, api_url)
    if response is None:
        raise gr.Error('The prediction service did not return a result.')
    img = Image.open(BytesIO(base64.b64decode(response['img_out'])))
    text = displaytext_yolocounter(response['countings_scinames'], response['coverage'])
    return img, text
def testing_directcounter(input_img):
    """Run the direct counter model and draw its counts onto the input image.

    Args:
        input_img: image array from the ``gr.Image`` input (``None`` if empty).

    Returns:
        ``(image, text)`` as produced by
        ``display_detectionsandcountings_directcounter``.

    Raises:
        gr.Error: if no image was provided or the service returned no result.
    """
    if input_img is None:
        raise gr.Error('Please provide an image first.')
    api_url = 'http://directcounter-test.us-east-1.elasticbeanstalk.com/predict'
    response = send2api(input_img, api_url)
    if response is None:
        raise gr.Error('The prediction service did not return a result.')
    countings = response['countings_scinames']
    img, text = display_detectionsandcountings_directcounter(input_img, countings, prob_th=0, cth = 0)
    return img, text
| #------------------------------------------------ | |
def extractframes(vid_path, max_seconds=10, target_fps=2):
    """Sample frames at roughly ``target_fps`` from the start of a video.

    Args:
        vid_path: path to a video file readable by moviepy's ``VideoFileClip``.
        max_seconds: only the first ``max_seconds`` seconds are sampled (the
            whole clip if it is shorter).
        target_fps: approximate number of frames kept per second of video.

    Returns:
        A list of frames as yielded by ``iter_frames`` (numpy arrays).
    """
    clip = VideoFileClip(vid_path)
    try:
        # Clamp the cut to the real duration; subclipping past the end makes
        # the reader run out of data and (in moviepy 1.x) repeat the last frame.
        sub = clip.subclip(0, min(max_seconds, clip.duration))
        # Integer stride of at least 1. The old float stride (e.g. 25/2 = 12.5)
        # only matched every 25th frame, and was 0 for fps < 0.5.
        stride = max(1, int(round(sub.fps / target_fps)))
        return [frame for i, frame in enumerate(sub.iter_frames()) if i % stride == 0]
    finally:
        # Release the ffmpeg reader process held by the clip.
        clip.close()
def processvideo(vid_path, detector, outvid_path="out.mp4"):
    """Run ``detector`` on sampled video frames and encode the results as a video.

    Args:
        vid_path: input video path.
        detector: callable taking one frame and returning ``(image, text)``;
            only the image is kept.
        outvid_path: where to write the output video. The default is a fixed
            file name shared by every call, so concurrent requests may
            overwrite each other's output; pass a unique path to avoid that.

    Returns:
        ``outvid_path``.

    Raises:
        ValueError: if no frames could be extracted from ``vid_path``.
    """
    frames = extractframes(vid_path)
    if not frames:
        raise ValueError(f"No frames could be extracted from {vid_path!r}")
    img_list = [np.asarray(img) for img, _text in map(detector, frames)]
    # 2 fps matches the 2 fps sampling rate used by extractframes.
    clip = ImageSequenceClip(img_list, fps=2)
    try:
        clip.write_videofile(outvid_path)
    finally:
        clip.close()
    return outvid_path
def video_identity(video):
    """Archive an uploaded video to S3, then return its annotated version.

    Args:
        video: path of the uploaded video file from the ``gr.Video`` input.

    Returns:
        The output video path produced by ``processvideo`` with
        ``testing_yolocounter`` as the per-frame detector.
    """
    vidupload2aws(video)
    return processvideo(video, detector=testing_yolocounter)
| #------------------------------------------------- | |
# Gradio UI: two image tabs driven by a single Submit button, plus a video tab.
with gr.Blocks() as demo:
    gr.Markdown("Submit an image with insects in a trap")
    with gr.Tab("Species & Common Name Count"):
        with gr.Row():
            input1 = gr.Image()
            output1 = [gr.Image(height=500, width=500), gr.Textbox(lines=20)]
        button1 = gr.Button("Submit")
        button1.click(testing_countingid, input1, output1)
    with gr.Tab("Simplified Scientific Name Count"):
        with gr.Row():
            output2 = [gr.Image(height=500, width=500), gr.Textbox(lines=20)]
        # This tab has no input of its own: the first tab's Submit button also
        # runs the YOLO counter on the same image.
        button1.click(testing_yolocounter, input1, output2)
    with gr.Tab("Processing on a video (under development)"):
        with gr.Row():
            input3 = gr.Video()
            output3 = gr.Video()
        button3 = gr.Button("Submit")
        button3.click(video_identity, input3, output3)
    # Show up to four randomly chosen example images for the image input.
    examples_list = glob("img_examples/*.jpg")
    random.shuffle(examples_list)
    examples = gr.Examples(examples=examples_list[:4], inputs=[input1])
    # Disabled "Direct insect counter" tab. To re-enable, uncomment; it uses
    # its own output name so it does not clash with the video tab's output3.
    # with gr.Tab("Direct insect counter"):
    #     with gr.Row():
    #         output4 = [gr.Image(height=500, width=500), gr.Textbox(lines=20)]
    #     button1.click(testing_directcounter, input1, output4)

if __name__ == "__main__":
    demo.launch()