manasch committed on
Commit
5f6a9dc
1 Parent(s): 27f3beb

Update image input to Pillow

Browse files
Files changed (4) hide show
  1. app.py +10 -8
  2. lib/image_captioning.py +9 -7
  3. lib/pace_model.py +4 -3
  4. requirements.txt +1 -0
app.py CHANGED
@@ -3,6 +3,8 @@ from pathlib import Path
3
  import numpy as np
4
  import gradio as gr
5
 
 
 
6
  from lib.image_captioning import ImageCaptioning
7
  from lib.pace_model import PaceModel
8
 
@@ -15,18 +17,18 @@ class AudioPalette:
15
  self.pace_model = PaceModel(height, width, channels, resnet50_tf_model_weights_path, pace_model_weights_path)
16
  self.image_captioning = ImageCaptioning()
17
 
18
- def generate(self, input_image_path):
19
- generated_text = self.image_captioning.query(input_image_path)[0].get("generated_text")
20
- return self.pace_model.predict(input_image_path) + " - " + generated_text
 
21
 
22
  def main():
23
  model = AudioPalette()
24
 
25
- gr.themes.Monochrome()
26
  demo = gr.Interface(
27
  fn=model.generate,
28
  inputs=gr.Image(
29
- type="filepath",
30
  label="Upload an image",
31
  show_label=True,
32
  container=True
@@ -34,15 +36,15 @@ def main():
34
  outputs=gr.Textbox(
35
  lines=1,
36
  placeholder="Pace of the image and the caption",
37
- label="Pace of the image",
38
  show_label=True,
39
  container=True,
40
  type="text"
41
  ),
42
  cache_examples=False,
43
  live=False,
44
- title="Predict Pace",
45
- description="Provide an image to determine the pace of the image",
46
  )
47
 
48
  demo.queue().launch()
 
3
  import numpy as np
4
  import gradio as gr
5
 
6
+ import PIL
7
+
8
  from lib.image_captioning import ImageCaptioning
9
  from lib.pace_model import PaceModel
10
 
 
17
  self.pace_model = PaceModel(height, width, channels, resnet50_tf_model_weights_path, pace_model_weights_path)
18
  self.image_captioning = ImageCaptioning()
19
 
20
+ def generate(self, input_image: PIL.Image.Image):
21
+ generated_text = self.image_captioning.query(input_image)[0].get("generated_text")
22
+ pace = self.pace_model.predict(input_image)
23
+ return pace + (" - " + generated_text if generated_text is not None else "")
24
 
25
  def main():
26
  model = AudioPalette()
27
 
 
28
  demo = gr.Interface(
29
  fn=model.generate,
30
  inputs=gr.Image(
31
+ type="pil",
32
  label="Upload an image",
33
  show_label=True,
34
  container=True
 
36
  outputs=gr.Textbox(
37
  lines=1,
38
  placeholder="Pace of the image and the caption",
39
+ label="Caption and Pace",
40
  show_label=True,
41
  container=True,
42
  type="text"
43
  ),
44
  cache_examples=False,
45
  live=False,
46
+ title="Audio Palette",
47
+ description="Provide an image to generate appropriate background soundtrack",
48
  )
49
 
50
  demo.queue().launch()
lib/image_captioning.py CHANGED
@@ -1,6 +1,8 @@
 
1
  import os
2
 
3
  import cv2
 
4
  import requests
5
 
6
  class ImageCaptioning:
@@ -12,16 +14,16 @@ class ImageCaptioning:
12
  self.org_token = os.environ["auth_token"]
13
  self.headers = { "Authorization": f"Bearer {self.org_token}" }
14
 
15
- def read_image(self, image_path):
16
- with open(image_path, "rb") as f:
17
- data = f.read()
18
-
19
- return data
20
 
21
- def query(self, image_path: str):
22
  response = requests.post(
23
  self.api_endpoint,
24
  headers=self.headers,
25
- data=self.read_image(image_path)
26
  )
 
27
  return response.json()
 
1
+ import io
2
  import os
3
 
4
  import cv2
5
+ import PIL
6
  import requests
7
 
8
  class ImageCaptioning:
 
14
  self.org_token = os.environ["auth_token"]
15
  self.headers = { "Authorization": f"Bearer {self.org_token}" }
16
 
17
+ def convert_to_bytes(self, image: PIL.Image.Image):
18
+ data = io.BytesIO()
19
+ image.save(data, format="PNG")
20
+ return data.getvalue()
 
21
 
22
+ def query(self, image: PIL.Image.Image):
23
  response = requests.post(
24
  self.api_endpoint,
25
  headers=self.headers,
26
+ data=self.convert_to_bytes(image)
27
  )
28
+ print(response.json())
29
  return response.json()
lib/pace_model.py CHANGED
@@ -3,6 +3,7 @@ import tensorflow as tf
3
 
4
  import cv2
5
  import keras
 
6
  from keras import Sequential
7
  from keras.applications.resnet50 import ResNet50
8
  from keras.layers import Flatten, Dense
@@ -45,9 +46,9 @@ class PaceModel:
45
 
46
  self.resnet_model.load_weights(self.pace_model_weights_path)
47
 
48
- def predict(self, input_image_path: str):
49
- input_image = cv2.imread(input_image_path)
50
- resized_image = cv2.resize(input_image, (self.height, self.width))
51
  image = np.expand_dims(resized_image, axis=0)
52
 
53
  prediction = self.resnet_model.predict(image)
 
3
 
4
  import cv2
5
  import keras
6
+ import PIL
7
  from keras import Sequential
8
  from keras.applications.resnet50 import ResNet50
9
  from keras.layers import Flatten, Dense
 
46
 
47
  self.resnet_model.load_weights(self.pace_model_weights_path)
48
 
49
+ def predict(self, input_image: PIL.Image.Image):
50
+ np_image = np.array(input_image)
51
+ resized_image = cv2.resize(np_image, (self.height, self.width))
52
  image = np.expand_dims(resized_image, axis=0)
53
 
54
  prediction = self.resnet_model.predict(image)
requirements.txt CHANGED
@@ -2,4 +2,5 @@ gradio
2
  keras
3
  numpy
4
  opencv-python
 
5
  tensorflow
 
2
  keras
3
  numpy
4
  opencv-python
5
+ pillow
6
  tensorflow