sowbaranika13 commited on
Commit
b8bf193
·
verified ·
1 Parent(s): bcc5edd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -9
app.py CHANGED
@@ -1,12 +1,21 @@
1
  import gradio as gr
2
  import tensorflow as tf
3
  from tensorflow.keras.models import load_model
4
- from tensorflow.keras.preprocessing.image import img_to_array, load_img
5
  import numpy as np
6
  import os
7
  import requests
8
 
9
- # Load your model and tokenizer
 
 
 
 
 
 
 
 
 
10
  labels = {
11
  'class': ['amphibia', 'aves', 'invertebrates', 'lacertilia', 'mammalia', 'serpentes', 'testudines'],
12
  'serpentes': ["Butler's Gartersnake", "Dekay's Brownsnake", 'Eastern Gartersnake', 'Eastern Hog-nosed snake', 'Eastern Massasauga', 'Eastern Milksnake', 'Eastern Racer Snake', 'Eastern Ribbonsnake', 'Gray Ratsnake', "Kirtland's Snake", 'Northern Watersnake', 'Plains Gartersnake', 'Red-bellied Snake', 'Smooth Greensnake'],
@@ -15,10 +24,26 @@ labels = {
15
  'amphibia': ['American Bullfrog', 'American Toad', 'Green Frog', 'Northern Leopard Frog']
16
  }
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  hierarchical_models = {}
19
  for label in labels:
20
- model_url = f"https://huggingface.co/spaces/sowbaranika13/ohio_space/resolve/main/inceptionv3_{label}.h5"
21
- hierarchical_models[label] = load_model(model_url)
22
 
23
  def load_and_preprocess_image(image, target_size=(299, 299)):
24
  img_array = img_to_array(image)
@@ -37,11 +62,10 @@ def predict(image):
37
  results['class'] = class_label
38
 
39
  # Predict species level
40
- if class_label in ['serpentes', 'mammalia', 'aves', 'amphibia']:
41
- species_preds = hierarchical_models[class_label].predict(image_array)
42
- species_idx = np.argmax(species_preds)
43
- species_label = labels[class_label][species_idx]
44
- results['species'] = species_label
45
 
46
  return results
47
 
 
1
  import gradio as gr
2
  import tensorflow as tf
3
  from tensorflow.keras.models import load_model
4
+ from tensorflow.keras.preprocessing.image import img_to_array
5
  import numpy as np
6
  import os
7
  import requests
8
 
9
+ # Define model URLs
10
+ model_urls = {
11
+ 'class': 'https://huggingface.co/spaces/sowbaranika13/ohio_space/resolve/main/inceptionv3_class.h5',
12
+ 'serpentes': 'https://huggingface.co/spaces/sowbaranika13/ohio_space/resolve/main/inceptionv3_serpentes.h5',
13
+ 'mammalia': 'https://huggingface.co/spaces/sowbaranika13/ohio_space/resolve/main/inceptionv3_mammalia.h5',
14
+ 'aves': 'https://huggingface.co/spaces/sowbaranika13/ohio_space/resolve/main/inceptionv3_aves.h5',
15
+ 'amphibia': 'https://huggingface.co/spaces/sowbaranika13/ohio_space/resolve/main/inceptionv3_amphibia.h5'
16
+ }
17
+
18
+ # Labels for class and species
19
  labels = {
20
  'class': ['amphibia', 'aves', 'invertebrates', 'lacertilia', 'mammalia', 'serpentes', 'testudines'],
21
  'serpentes': ["Butler's Gartersnake", "Dekay's Brownsnake", 'Eastern Gartersnake', 'Eastern Hog-nosed snake', 'Eastern Massasauga', 'Eastern Milksnake', 'Eastern Racer Snake', 'Eastern Ribbonsnake', 'Gray Ratsnake', "Kirtland's Snake", 'Northern Watersnake', 'Plains Gartersnake', 'Red-bellied Snake', 'Smooth Greensnake'],
 
24
  'amphibia': ['American Bullfrog', 'American Toad', 'Green Frog', 'Northern Leopard Frog']
25
  }
26
 
27
+ # Download and save models locally
28
+ def download_model(url, model_path):
29
+ if not os.path.exists(model_path):
30
+ response = requests.get(url)
31
+ with open(model_path, 'wb') as file:
32
+ file.write(response.content)
33
+
34
+ # Ensure the models directory exists
35
+ os.makedirs("models", exist_ok=True)
36
+
37
+ # Download all models
38
+ for label, url in model_urls.items():
39
+ model_path = os.path.join("models", f"inceptionv3_{label}.h5")
40
+ download_model(url, model_path)
41
+
42
+ # Load models
43
  hierarchical_models = {}
44
  for label in labels:
45
+ model_path = os.path.join("models", f"inceptionv3_{label}.h5")
46
+ hierarchical_models[label] = load_model(model_path)
47
 
48
  def load_and_preprocess_image(image, target_size=(299, 299)):
49
  img_array = img_to_array(image)
 
62
  results['class'] = class_label
63
 
64
  # Predict species level
65
+ species_preds = hierarchical_models[class_label].predict(image_array)
66
+ species_idx = np.argmax(species_preds)
67
+ species_label = labels[class_label][species_idx]
68
+ results['species'] = species_label
 
69
 
70
  return results
71