insecta / app.py
MuGeminorum
add output txt
70c99ee
raw
history blame contribute delete
No virus
3.82 kB
import os
import cv2
import khandy
import requests
import numpy as np
import gradio as gr
from tqdm import tqdm
from PIL import Image
from insectid import InsectDetector
from insectid import InsectIdentifier
def download_model(url, local_path, timeout=60):
    """Download ``url`` to ``local_path`` unless that file already exists.

    The payload is streamed into ``<local_path>.part`` and only moved onto
    ``local_path`` after the whole transfer succeeded. An interrupted or
    failed download therefore never leaves a truncated file behind that a
    later run (or the caller's mirror fallback) would mistake for a valid
    cached model.

    Args:
        url: Source URL of the file.
        local_path: Destination path on disk. Missing parent directories
            are created.
        timeout: Seconds ``requests`` waits for the connection and for each
            read before giving up.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On other network failures.
    """
    # Already downloaded: nothing to do.
    if os.path.exists(local_path):
        return

    print(f"Downloading file from {url}...")
    target_dir = os.path.dirname(local_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    tmp_path = local_path + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this, an HTTP error page would be written out as the
            # model file and no exception would reach the caller.
            response.raise_for_status()
            # Total size in bytes; 0 when the server sends no Content-Length.
            total_size = int(response.headers.get("content-length", 0))
            with open(tmp_path, "wb") as file, tqdm(
                total=total_size or None, unit="B", unit_scale=True
            ) as progress_bar:
                for chunk in response.iter_content(chunk_size=1024):
                    file.write(chunk)
                    progress_bar.update(len(chunk))
        # Atomically publish the completed file under its final name.
        os.replace(tmp_path, local_path)
    except BaseException:
        # Drop the partial download so a retry starts from scratch.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print("Download completed.")
# Longest image side (pixels) passed to the detector; larger uploads are downscaled.
MAX_IMAGE_SIDE = 1280
# Detections narrower or shorter than this many pixels are skipped.
MIN_BOX_SIZE = 30
# Identifications below this probability are captioned "Unknown".
MIN_PROBABILITY = 0.10

# Lazily created (detector, identifier) pair shared across requests.
_models = None


def _get_models():
    """Return the shared ``(InsectDetector, InsectIdentifier)`` pair, building it on first use.

    The original code re-created both models on every request. They are
    built once here and reused. Creation is deferred until the first call so
    it happens after the module-level model download.
    """
    global _models
    if _models is None:
        _models = (InsectDetector(), InsectIdentifier())
    return _models


def inference(filename):
    """Detect insects in an image and caption each detection with its species.

    Args:
        filename: Path of the uploaded image. Falls back to the bundled
            butterfly example when empty or None (nothing uploaded).

    Returns:
        ``(image, species)``:

        - ``image`` is an RGB ``PIL.Image``. Each detection at least
          ``MIN_BOX_SIZE`` pixels wide and high gets a green box and a
          ``"<latin name>: <prob>%"`` caption, or ``"Unknown"`` below
          ``MIN_PROBABILITY``.
        - ``species`` is the Latin name of the highest-probability
          identification, or ``"Unknown"`` if none qualified.

        Returns ``(None, None)`` when the image cannot be read, so the result
        still matches the interface's two outputs.
    """
    if not filename:
        filename = "./examples/butterfly.jpg"

    image = khandy.imread(filename)
    if image is None:
        return None, None

    detector, identifier = _get_models()

    if max(image.shape[:2]) > MAX_IMAGE_SIDE:
        image = khandy.resize_image_long(image, MAX_IMAGE_SIDE)

    image_for_draw = image.copy()
    image_height, image_width = image.shape[:2]
    boxes, confs, classes = detector.detect(image)

    best_name, best_prob = "Unknown", -1.0
    for box, _, _ in zip(boxes, confs, classes):
        box = box.astype(np.int32)
        box_width = box[2] - box[0] + 1
        box_height = box[3] - box[1] + 1
        if box_width < MIN_BOX_SIZE or box_height < MIN_BOX_SIZE:
            continue

        cropped = khandy.crop_or_pad(image, box[0], box[1], box[2], box[3])
        results = identifier.identify(cropped)

        # Reset per box so a low-confidence detection never inherits the
        # caption of an earlier, confident one.
        text = "Unknown"
        if results:
            # NOTE(review): assumes results[0] is the top-ranked candidate.
            top = results[0]
            print(top)
            prob = top["probability"]
            if prob >= MIN_PROBABILITY:
                text = "{}: {:.2f}%".format(top["latin_name"], 100.0 * prob)
                if prob > best_prob:
                    best_name, best_prob = top["latin_name"], prob

        # Caption sits just above the box's top-left corner, clamped to the image.
        position = [
            min(max(box[0] + 2, 0), image_width),
            min(max(box[1] - 20, 0), image_height),
        ]
        cv2.rectangle(
            image_for_draw, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2
        )
        image_for_draw = khandy.draw_text(
            image_for_draw, text, position, font="simsun.ttc", font_size=15
        )

    # Channel order is reversed before building the RGB PIL image
    # (presumably BGR -> RGB, as loaded by khandy/cv2).
    return Image.fromarray(image_for_draw[:, :, ::-1], mode="RGB"), best_name
domain = "https://huggingface.co/MuGeminorum/insecta/resolve/main/"
domain_zh = "https://www.modelscope.cn/api/v1/models/MuGeminorum/insecta/repo?Revision=master&FilePath="


def _ensure_models():
    """Make sure both ONNX model files exist locally before the UI starts.

    Each file is fetched independently: the Hugging Face host is tried first,
    then the ModelScope mirror. A failure on one file does not force the
    other back through the mirror, and every failure is printed instead of
    being silently swallowed.

    Raises:
        RuntimeError: If a model file could not be fetched from any mirror.
            The last underlying error is chained as ``__cause__``.
    """
    for model_name in (
        "quarrying_insect_detector.onnx",
        "quarrying_insect_identifier.onnx",
    ):
        local_path = f"./insectid/models/{model_name}"
        last_error = None
        for base_url in (domain, domain_zh):
            try:
                download_model(f"{base_url}{model_name}", local_path)
                break
            except Exception as err:  # network/HTTP failure: try the next mirror
                print(f"Download of {model_name} from {base_url} failed: {err}")
                last_error = err
        else:
            raise RuntimeError(
                f"Could not download {model_name} from any mirror"
            ) from last_error


_ensure_models()
# Build the Gradio UI: one insect photo in; the annotated detection image and
# the most probable species name out. Bundled sample photos are offered as examples.
photo_input = gr.Image(label="Upload insect photo", type="filepath")
result_outputs = [
    gr.Image(label="Detection result"),
    gr.Textbox(label="Most probable species", show_copy_button=True),
]
sample_photos = ["./examples/butterfly.jpg", "./examples/beetle.jpg"]

iface = gr.Interface(
    fn=inference,
    inputs=photo_input,
    outputs=result_outputs,
    examples=sample_photos,
)
iface.launch()