Spaces:
Sleeping
Sleeping
File size: 5,563 Bytes
733459a 9dc317c 1fb7b02 9dc317c 6883ad5 9dc317c 733459a 9dc317c 2cf2254 11bfa27 9dc317c 6883ad5 2cf2254 ab1042a 9dc317c 24f2e31 9dc317c 733459a 9dc317c 1fb7b02 9dc317c 1fb7b02 24f2e31 1fb7b02 733459a 9dc317c 934856d 9dc317c 1fb7b02 9dc317c 1fb7b02 9dc317c 1fb7b02 9dc317c 1fb7b02 9dc317c 24f2e31 9dc317c e540259 9dc317c 24f2e31 9dc317c e540259 9dc317c 934856d 9dc317c e5fb4db 934856d 9dc317c 733459a 9dc317c 733459a 9dc317c 1fb7b02 9dc317c 1fb7b02 9dc317c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
import functools
import gdown
from collections import Counter
import os
import torch
from S1_CNN_Model import CNN_Model
import gradio as gr
import numpy as np
import cv2
from SpeciesDetail import labels, SpeciesDetail
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODEL_LINK = "https://drive.google.com/file/d/18-t2jMpXLxtqE-8Bu0_NNNuie_mguSON/view?usp=sharing"
MODEL_PATH = "model.pt"

# Fetch the pickled model from Google Drive only when no local copy exists.
if not os.path.exists(MODEL_PATH):
    print("Downloading model . . . ")
    gdown.download(MODEL_LINK, MODEL_PATH, fuzzy=True)

# map_location lets a checkpoint that was saved on a GPU load on a CPU-only host.
# NOTE(review): torch.load unpickles arbitrary objects; acceptable only because
# MODEL_LINK is a fixed, trusted source. On PyTorch >= 2.6 the default
# weights_only=True rejects a fully pickled nn.Module - pass weights_only=False
# there if loading fails.
model: CNN_Model = torch.load(MODEL_PATH, map_location=device)
model.to(device)
# NOTE(review): presumably read by predict_large_image to place its inputs - confirm in S1_CNN_Model.
model.device = device
def listdir_full(path: str) -> list[str]:
    """Return the entries of directory *path* as ``"{path}/{name}"`` strings.

    Entries are sorted by name so that UI elements built from them appear in
    a stable order; ``os.listdir`` itself guarantees no particular order.
    """
    return [f"{path}/{name}" for name in sorted(os.listdir(path))]
# Display name of every species, in the same order as ``labels``; classify
# indexes this list with the model's predicted label ids.
label_names = [species.name for species in labels]
class History:
    """One classification-history entry: a thumbnail plus its prediction name."""

    cols = ["Image", "Prediction"]  # column headers for a history table

    def __init__(self, img, name) -> None:
        thumbnail = resize_image(img)
        self.img = thumbnail
        self.name = name
import sqlite3
def fetch_data(id: int):
    """Look up the species-detail row whose ``id`` column equals *id*.

    Args:
        id: Row id in ``my_table``; classify passes the predicted label here.

    Returns:
        A ``SpeciesDetail`` built positionally from every column after the
        first one.

    Raises:
        KeyError: If ``my_table`` has no row with this id.
    """
    # sqlite3's connection context manager only commits/rolls back a
    # transaction; it does not close the connection, so close it explicitly.
    conn = sqlite3.connect('my_database.db')
    try:
        row = conn.execute('SELECT * FROM my_table WHERE id = ?', (id,)).fetchone()
    finally:
        conn.close()
    if row is None:
        raise KeyError(f"no species detail with id {id}")
    # the leading column is discarded (presumably the id itself); the rest
    # map positionally onto SpeciesDetail's constructor
    _, *detail = row
    return SpeciesDetail(*detail)
# Pixel length of the longer side of thumbnails produced by resize_image.
MAX_IMG_LEN: int = 160
def resize_image(img, max_len=None):
    """Scale *img* so its longer side is ``max_len`` pixels, keeping aspect ratio.

    Args:
        img: Image array whose first two dimensions are (height, width); a
            trailing channel axis is optional.
        max_len: Target length of the longer side; defaults to ``MAX_IMG_LEN``.

    Returns:
        The array returned by ``cv2.resize``. Images smaller than ``max_len``
        are scaled up as well.
    """
    if max_len is None:
        max_len = MAX_IMG_LEN
    # shape[:2] works for both (h, w) grayscale and (h, w, c) color arrays
    h, w = img.shape[:2]
    if w > h:
        new_w, new_h = max_len, int(h / w * max_len)
    else:
        new_w, new_h = int(w / h * max_len), max_len
    # int() truncates to 0 for very thin images, and cv2.resize rejects a
    # zero-area target size, so keep every side at least 1 pixel
    return cv2.resize(img, (max(new_w, 1), max(new_h, 1)))
PD_COLS: list[str] = ["image", "predicted species"]  # NOTE(review): not referenced in the visible code
MAX_HISTORY: int = 10  # number of history slots on the History tab
MAX_PREDS: int = 10  # number of species-ratio boxes on the Classification tab
def classify(image: np.ndarray, history):
    """Classify *image* and build updates for the prediction and history views.

    Args:
        image: Array from the Classification tab's ``gr.Image`` (converted
            from RGB to BGR before prediction), or ``None`` if nothing was
            uploaded.
        history: Per-session list of ``(thumbnail, markdown)`` tuples held in
            a ``gr.State``; ``None`` is treated as empty. It is not mutated.

    Returns:
        A flat tuple matching the ``submit.click`` outputs: the prediction
        Markdown, ``MAX_PREDS`` ratio Textboxes, the history-slot updates from
        ``toggle_history_components``, and the new history state.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    # copy so the incoming state object is never mutated in place
    history = [] if history is None else list(history)
    with torch.no_grad():
        r, p = model.predict_large_image(cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
    # r: presumably one predicted label per image region; show the
    # MAX_PREDS most frequent labels (most_common(n) returns them in
    # descending order, unlike slicing the tail of the full list)
    ratios = [
        gr.Textbox(f"{label_names[label]}: {count/len(r)*100:.2f}%", visible=True)
        for label, count in Counter(r.tolist()).most_common(MAX_PREDS)
    ]
    ratios += [gr.Textbox(visible=False)] * (MAX_PREDS - len(ratios))
    detail = fetch_data(p.item())
    pred = gr.Markdown(detail.result_text())
    history.append((resize_image(image), f"<h2>{detail.name}</h2> \n {detail.desc}"))
    # keep only what the History tab can display so session state stays bounded
    history = history[-MAX_HISTORY:]
    return pred, *ratios, *toggle_history_components(history), history
def toggle_history_components(history: list[tuple]):
    """Build updates for the History tab's fixed image and caption slots.

    Args:
        history: ``(thumbnail, markdown)`` tuples, oldest first. Only the last
            ``MAX_HISTORY`` entries are shown; an empty list hides every slot.

    Returns:
        ``MAX_HISTORY`` ``gr.Image`` updates followed by ``MAX_HISTORY``
        ``gr.Markdown`` updates; filled slots are visible, the rest hidden.
    """
    # history_tab builds exactly MAX_HISTORY slots of each kind, so never
    # emit more updates than that
    shown = list(history)[-MAX_HISTORY:]
    n_hidden = MAX_HISTORY - len(shown)
    # explicit unpacking instead of zip(*...) so an empty history works
    images = [img for img, _ in shown]
    names = [name for _, name in shown]
    components = [gr.Image(x, visible=True) for x in images]
    components += [gr.Image(visible=False)] * n_hidden
    components += [gr.Markdown(x, visible=True) for x in names]
    components += [gr.Markdown(visible=False)] * n_hidden
    return components
def classification_tab():
    """Lay out the Classification tab inside the current Gradio context.

    Returns:
        A 5-tuple: the image input, the submit button, the clear button, the
        prediction Markdown, and a list of ``MAX_PREDS`` hidden Textboxes used
        for per-species ratios.
    """
    with gr.Row():
        # left column: input image and its action buttons
        with gr.Column():
            image = gr.Image()
            with gr.Row():
                submit = gr.Button("Submit", variant='primary')
                clear = gr.ClearButton(image)
        # right column: prediction text plus hidden ratio slots
        with gr.Column():
            pred = gr.Markdown("## Predictions")
            ratios = [
                gr.Textbox(show_label=False, visible=False)
                for _slot in range(MAX_PREDS)
            ]
    return image, submit, clear, pred, ratios
SAMPLE_DIR = "data/image/test_full"
# Number of sample panels to pre-build: the largest entry count of any
# sub-folder of SAMPLE_DIR. Stray files beside the folders are skipped, and
# an empty SAMPLE_DIR yields 0 panels instead of crashing max().
MAX_SAMPLE_COUNT = max(
    (len(os.listdir(d)) for d in listdir_full(SAMPLE_DIR) if os.path.isdir(d)),
    default=0,
)
def sample_tab(image_input, tabs):
    """Build the Samples tab: pick a species, browse its sample images, and
    send one to the Classification tab.

    Args:
        image_input: The Classification tab's ``gr.Image`` that receives the
            chosen sample.
        tabs: The enclosing ``gr.Tabs``; switched to tab id 0 after a choice.
    """
    def choose_image(image):
        # echo the sample in its panel, copy it to the classifier input,
        # and jump to the Classification tab (id=0)
        return gr.Image(image), gr.Image(image), gr.Tabs(selected=0)

    def refresh_samples(species):
        species_dir = f"{SAMPLE_DIR}/{species}"
        # a cleared dropdown or a species without a sample folder shows no panels
        paths = listdir_full(species_dir) if os.path.isdir(species_dir) else []
        # there are only MAX_SAMPLE_COUNT panels to fill
        paths = paths[:MAX_SAMPLE_COUNT]
        n_hidden = MAX_SAMPLE_COUNT - len(paths)
        components = [gr.Image(i, visible=True) for i in paths]
        components += [gr.Image(visible=False)] * n_hidden
        components += [gr.Button(visible=True) for _ in paths]
        components += [gr.Button(visible=False)] * n_hidden
        return components

    dropdown = gr.Dropdown(label_names, label="Species", value="Select a Species")
    images = []
    buttons = []

    def sample_panel():
        # one hidden image + submit button pair, revealed by refresh_samples
        with gr.Column():
            image = gr.Image(visible=False, interactive=False, min_width=1)
            select = gr.Button("Submit", variant='primary', visible=False)
        images.append(image)
        buttons.append(select)
        select.click(choose_image, image, [image, image_input, tabs])

    with gr.Row():
        for _ in range(MAX_SAMPLE_COUNT):
            sample_panel()
    dropdown.change(refresh_samples, dropdown, images + buttons)
def history_tab():
    """Lay out the History tab and return its output slots.

    Returns:
        ``MAX_HISTORY`` hidden ``gr.Image`` slots followed by ``MAX_HISTORY``
        hidden ``gr.Markdown`` slots.
    """
    thumbnails = []
    captions = []
    # column headers
    with gr.Row():
        gr.Markdown("# Image")
        with gr.Column(scale=2):
            gr.Markdown("# Species")
    # one row per history slot: image on the left, wider caption on the right
    with gr.Column():
        for _slot in range(MAX_HISTORY):
            with gr.Row():
                thumbnails.append(gr.Image(height=200, visible=False))
                with gr.Column(scale=2):
                    captions.append(gr.Markdown("", visible=False))
    return [*thumbnails, *captions]
# Explicit UTF-8: the default text encoding is locale-dependent, so non-ASCII
# Markdown could fail to decode or be garbled on some hosts.
with open('homepage.md', 'r', encoding='utf-8') as file:
    home_screen_markdown = file.read()
with gr.Blocks() as demo:
    # per-session list of (thumbnail, markdown) entries produced by classify
    history = gr.State([])
    with gr.Tabs() as tabs:
        with gr.Tab("Home", id=3):
            gr.Markdown(home_screen_markdown)
        with gr.Tab("Classification", id=0):
            image, submit, clear, pred, ratios = classification_tab()
        with gr.Tab("Samples", id=1):
            sample_tab(image, tabs)
        with gr.Tab("History", id=2):
            table_contents = history_tab()
    # output order must match the tuple returned by classify
    click_outputs = [pred, *ratios, *table_contents, history]
    submit.click(classify, [image, history], click_outputs)
demo.launch()
|