sowbaranika13 committed
Commit c1dd438
1 Parent(s): 081c58a

Update app.py

Files changed (1): app.py +79 -21
app.py CHANGED
@@ -1,34 +1,92 @@
import gradio as gr
- import torch
- from transformers import AutoModelForImageClassification, AutoTokenizer
- from PIL import Image

# Load your model and tokenizer
- model_name = "inceptionv3_class.h5"
- model = AutoModelForImageClassification.from_pretrained(model_name)
- tokenizer = AutoTokenizer.from_pretrained(model_name)

def predict(image):
-     # Preprocess the image
-     inputs = tokenizer(images=image, return_tensors="pt")
-
-     # Predict
-     outputs = model(**inputs)
-     logits = outputs.logits
-     probs = torch.nn.functional.softmax(logits, dim=-1)
-
-     # Get top 5 predictions
-     top_probs, top_labels = torch.topk(probs, 5)
-
-     # Map labels to human-readable class names
-     class_names = [model.config.id2label[label.item()] for label in top_labels[0]]
-     return {name: prob.item() for name, prob in zip(class_names, top_probs[0])}

# Create Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(type="pil"),
-     outputs=gr.outputs.Label(num_top_classes=5),
    title="Image Classification",
    description="Upload an image to classify it into species and class level."
)
 
import gradio as gr
+ import tensorflow as tf
+ from tensorflow.keras.models import load_model
+ from tensorflow.keras.preprocessing.image import img_to_array, load_img
+ import numpy as np
+ import os
+ import shutil
+ from docx import Document
+ from docx.shared import Inches
+ import matplotlib.pyplot as plt

# Load your model and tokenizer
+ class_thresholds = {
+     'class': 0.6,
+     'amphibia': 0.1,
+     'aves': 0.5,
+     'mammalia': 0.8,
+     'serpentes': 0.9
+ }
+
+ labels = {
+     'class': ['amphibia', 'aves', 'invertebrates', 'lacertilia', 'mammalia', 'serpentes', 'testudines'],
+     'serpentes': ["Butler's Gartersnake", "Dekay's Brownsnake", 'Eastern Gartersnake', 'Eastern Hog-nosed snake', 'Eastern Massasauga', 'Eastern Milksnake', 'Eastern Racer Snake', 'Eastern Ribbonsnake', 'Gray Ratsnake', "Kirtland's Snake", 'Northern Watersnake', 'Plains Gartersnake', 'Red-bellied Snake', 'Smooth Greensnake'],
+     'mammalia': ['American Mink', 'Brown Rat', 'Eastern Chipmunk', 'Eastern Cottontail', 'Long-tailed Weasel', 'Masked Shrew', 'Meadow Jumping Mouse', 'Meadow Vole', 'N. Short-tailed Shrew', 'Raccoon', 'Star-nosed mole', 'Striped Skunk', 'Virginia Opossum', 'White-footed Mouse', 'Woodchuck', 'Woodland Jumping Mouse'],
+     'aves': ['Common Yellowthroat', 'Gray Catbird', 'Indigo Bunting', 'Northern House Wren', 'Song Sparrow', 'Sora'],
+     'amphibia': ['American Bullfrog', 'American Toad', 'Green Frog', 'Northern Leopard Frog']
+ }
+
+ hierarchical_models = {}
+ for label in labels:
+     model_path = f"inceptionv3_{label}.h5"
+     if os.path.exists(model_path):
+         hierarchical_models[label] = load_model(model_path)
+
+ def load_and_preprocess_image(image_path, target_size=(299, 299)):
+     img = load_img(image_path, target_size=target_size)
+     img_array = img_to_array(img)
+     img_array = np.expand_dims(img_array, axis=0)
+     img_array = tf.keras.applications.inception_v3.preprocess_input(img_array)
+     return img_array
+
+ def copy_images(source_dir, image_path, label, confidence, threshold):
+     if confidence < threshold:
+         human_review_dir = os.path.join(source_dir, 'human', label)
+         os.makedirs(human_review_dir, exist_ok=True)
+         human_review_path = os.path.join(human_review_dir, os.path.basename(image_path))
+         shutil.copy(image_path, human_review_path)
+     else:
+         output_path = os.path.join(source_dir, label)
+         os.makedirs(output_path, exist_ok=True)
+         output_dir = os.path.join(output_path, os.path.basename(image_path))
+         shutil.copy(image_path, output_dir)
+
+ def process_images(input_d, level, dir_name):
+     input_dir = os.path.join(input_d, dir_name)
+     image_paths = [os.path.join(input_dir, fname) for fname in os.listdir(input_dir)]
+     images = np.vstack([load_and_preprocess_image(image_path) for image_path in image_paths])
+     predictions = hierarchical_models[level].predict(images)
+     for image_path, prediction in zip(image_paths, predictions):
+         predicted_class_index = np.argmax(prediction)
+         predicted_class_label = labels[level][predicted_class_index]
+         confidence = prediction[predicted_class_index]
+         copy_images(input_d, image_path, predicted_class_label, confidence, class_thresholds[level])
+     if level == "class":
+         for label in ['serpentes', 'mammalia', 'aves', 'amphibia']:
+             class_dir = os.path.join(input_d, label)
+             if os.path.exists(class_dir):
+                 process_images(input_d, label, label)

def predict(image):
+     image_path = "temp_image.jpg"
+     image.save(image_path)
+     level = "class"
+     dir_name = "temp_images"
+     os.makedirs(dir_name, exist_ok=True)
+     shutil.copy(image_path, os.path.join(dir_name, os.path.basename(image_path)))
+     process_images(".", level, dir_name)
+     predictions = {}
+     for label in ['serpentes', 'mammalia', 'aves', 'amphibia']:
+         class_dir = os.path.join(".", label)
+         if os.path.exists(class_dir):
+             predictions[label] = os.listdir(class_dir)
+     return predictions

# Create Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(type="pil"),
+     outputs=gr.outputs.Label(num_top_classes=7),
    title="Image Classification",
    description="Upload an image to classify it into species and class level."
)
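
The new app.py classifies each upload in two stages: a class-level InceptionV3 model runs first, and when the class prediction clears its threshold the matching species-level model runs next, while low-confidence images are copied into a 'human' review folder. Purely as a reading aid, here is a minimal sketch (not part of this commit) of that two-stage idea for a single in-memory PIL image; it assumes the hierarchical_models, labels, and class_thresholds dictionaries defined above, assumes inceptionv3_class.h5 was found at load time, and uses predict_single as a hypothetical helper name.

import numpy as np
import tensorflow as tf

def predict_single(pil_image):
    # Hypothetical helper (not in the commit): class -> species prediction for one
    # PIL image, reusing hierarchical_models, labels, and class_thresholds from app.py.
    img = pil_image.convert("RGB").resize((299, 299))
    x = tf.keras.applications.inception_v3.preprocess_input(
        np.expand_dims(np.array(img, dtype="float32"), axis=0))

    # Stage 1: class-level model (amphibia, aves, mammalia, serpentes, ...).
    class_probs = hierarchical_models["class"].predict(x)[0]
    class_idx = int(np.argmax(class_probs))
    class_label = labels["class"][class_idx]
    result = {"class": (class_label, float(class_probs[class_idx]))}

    # Stage 2: only if the class prediction clears its threshold and a
    # species-level model was loaded for that class.
    if (class_probs[class_idx] >= class_thresholds["class"]
            and class_label in hierarchical_models):
        species_probs = hierarchical_models[class_label].predict(x)[0]
        species_idx = int(np.argmax(species_probs))
        result["species"] = (labels[class_label][species_idx],
                             float(species_probs[species_idx]))
    return result

# Example (hypothetical filename):
# from PIL import Image
# print(predict_single(Image.open("green_frog.jpg")))

Unlike this sketch, the committed predict() works through the filesystem (temp_image.jpg, temp_images/, and per-label folders), so that copy_images can sort confident predictions into class and species directories and route low-confidence images to the 'human' review directory.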