nbiish commited on
Commit
c96c440
β€’
1 Parent(s): 0e69072

Refactor app.py to use gemini client for bird feather detection

Browse files
Files changed (4) hide show
  1. app.py +32 -29
  2. package.txt +3 -8
  3. list.csv β†’ protected-birds.txt +0 -1
  4. requirements.txt +5 -9
app.py CHANGED
@@ -1,42 +1,44 @@
1
  import gradio as gr
2
- import pandas as pd
3
- import numpy as np
4
- import os
5
- from PIL import Image
6
- import requests
7
- from io import BytesIO
8
- from imageai.Prediction import ImagePrediction
9
  import csv
 
 
 
10
 
11
- # Load the protected-birds.csv file
12
- protected_birds = pd.read_csv('protected-birds.csv')
13
- protected_birds = protected_birds.dropna()
14
- protected_birds = protected_birds['Common Name'].str.lower().values
 
 
15
 
16
- # Load the image ai model
17
- model = ImagePrediction()
18
- model.setModelTypeAsResNet()
19
- model.setModelPath("resnet50_weights_tf_dim_ordering_tf_kernels.h5")
20
- model.loadModel()
21
 
22
  # Function to check if the feather is from a protected bird
23
- def check_feather(image):
24
  # Load the image
25
- response = requests.get(image)
26
- img = Image.open(BytesIO(response.content))
27
- img.save('image.jpg')
28
 
29
  # Make a prediction
30
- model_predictions, probabilities = model.predictImage('image.jpg', result_count=5)
 
 
 
 
 
 
 
 
31
 
32
  # Check if the feather is from a protected bird
33
  protected = False
34
  protected_birds_list = []
35
- for bird in protected_birds:
36
- for prediction in model_predictions:
37
- if bird in prediction.lower():
38
- protected = True
39
- protected_birds_list.append([bird, prediction, probabilities[model_predictions.index(prediction)]])
40
 
41
  if protected:
42
  # Display the results
@@ -48,10 +50,11 @@ def check_feather(image):
48
  # Prompt user to make a report
49
  gr.Interface(fn=lambda: "If confident, report to: https://www.fws.gov/contact-us", inputs=user_verification, outputs="text").launch()
50
  else:
51
- gr.Interface(fn=lambda: "No protected bird detected", inputs=image, outputs="text").launch()
52
 
53
  return protected, protected_birds_list
54
 
55
  # Gradio interface
56
- image = gr.inputs.Image()
57
- gr.Interface(fn=check_feather, inputs=image, outputs="text").launch()
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
2
  import csv
3
+ import gemini
4
+ import base64
5
+ import httpx
6
 
7
+ # Load the protected-birds.txt file into a dictionary
8
+ protected_birds_dict = {}
9
+ with open('protected-birds.txt', 'r') as txtfile:
10
+ for line in txtfile:
11
+ bird, confidence = line.strip().split(',')
12
+ protected_birds_dict[bird.lower()] = float(confidence)
13
 
14
+ # Initialize the gemini client
15
+ def init_client(api_key):
16
+ return gemini.Client(api_key)
 
 
17
 
18
  # Function to check if the feather is from a protected bird
19
+ def check_feather(image_url, api_key):
20
  # Load the image
21
+ image_media_type = "image/jpeg"
22
+ image_data = base64.b64encode(httpx.get(image_url).content).decode("utf-8")
 
23
 
24
  # Make a prediction
25
+ client = init_client(api_key)
26
+ response = client.predict(
27
+ model_id="text-bison-001",
28
+ inputs={
29
+ "image": image_data,
30
+ "question": "Describe this image."
31
+ }
32
+ )
33
+ model_predictions = response["candidates"][0]["output"]
34
 
35
  # Check if the feather is from a protected bird
36
  protected = False
37
  protected_birds_list = []
38
+ for bird, confidence in protected_birds_dict.items():
39
+ if bird in model_predictions.lower():
40
+ protected = True
41
+ protected_birds_list.append([bird, confidence, model_predictions])
 
42
 
43
  if protected:
44
  # Display the results
 
50
  # Prompt user to make a report
51
  gr.Interface(fn=lambda: "If confident, report to: https://www.fws.gov/contact-us", inputs=user_verification, outputs="text").launch()
52
  else:
53
+ gr.Interface(fn=lambda: "No protected bird detected", inputs=image_url, outputs="text").launch()
54
 
55
  return protected, protected_birds_list
56
 
57
  # Gradio interface
58
+ image_url = gr.inputs.Textbox()
59
+ api_key = gr.inputs.Textbox(label="Google API Key")
60
+ gr.Interface(fn=check_feather, inputs=[image_url, api_key], outputs="text").launch()
package.txt CHANGED
@@ -1,8 +1,3 @@
1
- gradio
2
- pandas
3
- numpy
4
- Pillow
5
- requests
6
- imageai
7
- tensorflow==2.4.0
8
- h5py==2.10.0
 
1
+ transformers==4.26.0
2
+ datasets==2.4.0
3
+ sentencepiece==0.1.96
 
 
 
 
 
list.csv β†’ protected-birds.txt RENAMED
@@ -1,4 +1,3 @@
1
- ENGLISH_NAME
2
  Black-bellied Whistling-Duck
3
  West Indian Whistling-Duck
4
  Fulvous Whistling-Duck
 
 
1
  Black-bellied Whistling-Duck
2
  West Indian Whistling-Duck
3
  Fulvous Whistling-Duck
requirements.txt CHANGED
@@ -1,10 +1,6 @@
1
  gradio
2
- pandas
3
- numpy
4
- Pillow
5
- requests
6
- imageai==2.1.6
7
- tensorflow>=2.4.0
8
- pytest
9
- mock
10
- coverage
 
1
  gradio
2
+ anthropic
3
+ httpx
4
+ transformers==4.26.0
5
+ datasets==2.4.0
6
+ sentencepiece==0.1.96