manasch committed on
Commit
8865845
1 Parent(s): 150837c

add ngrok endpoint as input

Browse files
Files changed (2) hide show
  1. app.py +19 -8
  2. lib/audio_generation.py +4 -6
app.py CHANGED
@@ -19,7 +19,7 @@ class AudioPalette:
19
  self.image_captioning = ImageCaptioning()
20
  self.audio_generation = AudioGeneration()
21
 
22
- def generate(self, input_image: PIL.Image.Image):
23
  pace = self.pace_model.predict(input_image)
24
  print("Pace Prediction Done")
25
 
@@ -29,7 +29,7 @@ class AudioPalette:
29
  generated_text = generated_text if generated_text is not None else ""
30
  prompt = f"Generate a soundtrack for {generated_text} with {pace} beats and the instrument of choice is the guitar, High quality"
31
 
32
- audio_file = self.audio_generation.generate(prompt)
33
  print("Audio Generation Done")
34
 
35
  outputs = [prompt, pace, generated_text, audio_file]
@@ -41,12 +41,23 @@ def main():
41
 
42
  demo = gr.Interface(
43
  fn=model.generate,
44
- inputs=gr.Image(
45
- type="pil",
46
- label="Upload an image",
47
- show_label=True,
48
- container=True
49
- ),
 
 
 
 
 
 
 
 
 
 
 
50
  outputs=[
51
  gr.Textbox(
52
  lines=1,
 
19
  self.image_captioning = ImageCaptioning()
20
  self.audio_generation = AudioGeneration()
21
 
22
+ def generate(self, input_image: PIL.Image.Image, ngrok_endpoint: str):
23
  pace = self.pace_model.predict(input_image)
24
  print("Pace Prediction Done")
25
 
 
29
  generated_text = generated_text if generated_text is not None else ""
30
  prompt = f"Generate a soundtrack for {generated_text} with {pace} beats and the instrument of choice is the guitar, High quality"
31
 
32
+ audio_file = self.audio_generation.generate(prompt, ngrok_endpoint)
33
  print("Audio Generation Done")
34
 
35
  outputs = [prompt, pace, generated_text, audio_file]
 
41
 
42
  demo = gr.Interface(
43
  fn=model.generate,
44
+ inputs=[
45
+ gr.Image(
46
+ type="pil",
47
+ label="Upload an image",
48
+ show_label=True,
49
+ container=True
50
+ ),
51
+ gr.Textbox(
52
+ lines=1,
53
+ placeholder="ngrok endpoint",
54
+ label="colab endpoint",
55
+ show_label=True,
56
+ container=True,
57
+ type="text",
58
+ visible=True
59
+ )
60
+ ],
61
  outputs=[
62
  gr.Textbox(
63
  lines=1,
lib/audio_generation.py CHANGED
@@ -6,13 +6,10 @@ import requests
6
 
7
  class AudioGeneration:
8
  def __init__(self):
9
- self.endpoint = os.environ["colab_ngrok_api_endpoint"]
10
- self.request_single_endpoint = self.endpoint + "single"
11
- self.download_endpoint = self.endpoint + "download"
12
  self.session = requests.session()
13
 
14
  def request_single(self, prompt: str):
15
- response = self.session.post(self.request_single_endpoint, json={
16
  "caption": prompt
17
  })
18
 
@@ -23,7 +20,7 @@ class AudioGeneration:
23
  pass
24
 
25
  def request_download(self, file_path: str):
26
- response = self.session.post(self.download_endpoint, json={
27
  "file_path": file_path
28
  })
29
 
@@ -34,7 +31,8 @@ class AudioGeneration:
34
 
35
  return audio_file_path
36
 
37
- def generate(self, prompt: typing.Union[str, typing.List[str]]):
 
38
  if isinstance(prompt, str):
39
  stored_file_path = self.request_single(prompt)
40
  audio_file = self.request_download(stored_file_path)
 
6
 
7
  class AudioGeneration:
8
  def __init__(self):
 
 
 
9
  self.session = requests.session()
10
 
11
  def request_single(self, prompt: str):
12
+ response = self.session.post(self.endpoint + "single", json={
13
  "caption": prompt
14
  })
15
 
 
20
  pass
21
 
22
  def request_download(self, file_path: str):
23
+ response = self.session.post(self.endpoint + "download", json={
24
  "file_path": file_path
25
  })
26
 
 
31
 
32
  return audio_file_path
33
 
34
+ def generate(self, prompt: typing.Union[str, typing.List[str]], endpoint: str):
35
+ self.endpoint = endpoint
36
  if isinstance(prompt, str):
37
  stored_file_path = self.request_single(prompt)
38
  audio_file = self.request_download(stored_file_path)