vmoras commited on
Commit
aa938da
·
1 Parent(s): 6e2fd22

Add option to select GPU

Browse files
Files changed (2) hide show
  1. app.py +8 -3
  2. functions.py +18 -14
app.py CHANGED
@@ -21,7 +21,12 @@ with (gr.Blocks() as app):
21
  with gr.Row():
22
  output_audio = gr.Audio(interactive=False, label='Audio', autoplay=True)
23
  with gr.Row():
24
- checkbox = gr.Checkbox(label='Get video', info='Remember that this has a cost')
 
 
 
 
 
25
 
26
  with gr.Column():
27
  with gr.Row():
@@ -59,13 +64,13 @@ with (gr.Blocks() as app):
59
 
60
  button_text.click(
61
  get_answer_text,
62
- [text, chat, user_id, checkbox],
63
  [video, output_audio, chat, text]
64
  )
65
 
66
  button_audio.click(
67
  get_answer_audio,
68
- [audio, chat, user_id, checkbox],
69
  [video, output_audio, chat, audio]
70
  )
71
 
 
21
  with gr.Row():
22
  output_audio = gr.Audio(interactive=False, label='Audio', autoplay=True)
23
  with gr.Row():
24
+ checkbox = gr.Checkbox(
25
+ label='Get video', info='Remember that this has a cost'
26
+ )
27
+ radio = gr.Radio(
28
+ choices=['small', 'big'], value='small', label='GPU', info='Select the size of GPU'
29
+ )
30
 
31
  with gr.Column():
32
  with gr.Row():
 
64
 
65
  button_text.click(
66
  get_answer_text,
67
+ [text, chat, user_id, checkbox, radio],
68
  [video, output_audio, chat, text]
69
  )
70
 
71
  button_audio.click(
72
  get_answer_audio,
73
+ [audio, chat, user_id, checkbox, radio],
74
  [video, output_audio, chat, audio]
75
  )
76
 
functions.py CHANGED
@@ -7,14 +7,6 @@ import gradio as gr
7
  from datetime import datetime
8
  from huggingface_hub import hf_hub_download, HfApi
9
 
10
- API_TOKEN = os.getenv('API_TOKEN')
11
- API_URL = os.getenv('API_URL')
12
-
13
- headers = {
14
- "Authorization": f"Bearer {API_TOKEN}",
15
- "Content-Type": "application/json"
16
- }
17
-
18
 
19
  def get_main_data():
20
  """
@@ -43,10 +35,18 @@ def make_visible():
43
  return gr.Row.update(visible=True), gr.Row.update(visible=True)
44
 
45
 
46
- def _query(payload):
47
  """
48
  Returns the json from a post request. It is done to the BellaAPI
49
  """
 
 
 
 
 
 
 
 
50
  response = requests.post(API_URL, headers=headers, json=payload)
51
  return response.json()
52
 
@@ -71,7 +71,7 @@ def init_chatbot(chatbot: list[tuple[str, str]]):
71
  "date": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
72
  'get_video': True
73
  }}
74
- output = _query(inputs)
75
 
76
  chatbot.append(('', output['answer']))
77
  _download_media(output['link_media'], 'video')
@@ -79,7 +79,9 @@ def init_chatbot(chatbot: list[tuple[str, str]]):
79
  return 'video.mp4', chatbot, output['user_id']
80
 
81
 
82
- def get_answer_text(question: str, chatbot: list[tuple[str, str]], user_id: str, checkbox: bool):
 
 
83
  """
84
  Gets the answer of the chatbot
85
  """
@@ -87,11 +89,13 @@ def get_answer_text(question: str, chatbot: list[tuple[str, str]], user_id: str,
87
  inputs = {'inputs': {
88
  'text': question, 'user_id': user_id, 'get_video': checkbox
89
  }}
90
- output = _query(inputs)
91
  return _update_elements(question, chatbot, output, checkbox, '')
92
 
93
 
94
- def get_answer_audio(audio_path, chatbot: list[tuple[str, str]], user_id: str, checkbox: bool):
 
 
95
  """
96
  Gets the answer of the chatbot
97
  """
@@ -104,7 +108,7 @@ def get_answer_audio(audio_path, chatbot: list[tuple[str, str]], user_id: str, c
104
  inputs = {'inputs': {
105
  'is_audio': True, 'audio': encoded_audio, 'user_id': user_id, 'get_video': checkbox
106
  }}
107
- output = _query(inputs)
108
 
109
  # Transcription of the audio
110
  question = output['question']
 
7
  from datetime import datetime
8
  from huggingface_hub import hf_hub_download, HfApi
9
 
 
 
 
 
 
 
 
 
10
 
11
  def get_main_data():
12
  """
 
35
  return gr.Row.update(visible=True), gr.Row.update(visible=True)
36
 
37
 
38
+ def _query(payload, size_gpu: str):
39
  """
40
  Returns the json from a post request. It is done to the BellaAPI
41
  """
42
+ API_TOKEN = os.getenv(f'API_TOKEN')
43
+ API_URL = os.getenv(f'API_URL_{size_gpu}')
44
+
45
+ headers = {
46
+ "Authorization": f"Bearer {API_TOKEN}",
47
+ "Content-Type": "application/json"
48
+ }
49
+
50
  response = requests.post(API_URL, headers=headers, json=payload)
51
  return response.json()
52
 
 
71
  "date": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
72
  'get_video': True
73
  }}
74
+ output = _query(inputs, 'small')
75
 
76
  chatbot.append(('', output['answer']))
77
  _download_media(output['link_media'], 'video')
 
79
  return 'video.mp4', chatbot, output['user_id']
80
 
81
 
82
+ def get_answer_text(
83
+ question: str, chatbot: list[tuple[str, str]], user_id: str, checkbox: bool, size_gpu: str
84
+ ):
85
  """
86
  Gets the answer of the chatbot
87
  """
 
89
  inputs = {'inputs': {
90
  'text': question, 'user_id': user_id, 'get_video': checkbox
91
  }}
92
+ output = _query(inputs, size_gpu)
93
  return _update_elements(question, chatbot, output, checkbox, '')
94
 
95
 
96
+ def get_answer_audio(
97
+ audio_path, chatbot: list[tuple[str, str]], user_id: str, checkbox: bool, size_gpu: str
98
+ ):
99
  """
100
  Gets the answer of the chatbot
101
  """
 
108
  inputs = {'inputs': {
109
  'is_audio': True, 'audio': encoded_audio, 'user_id': user_id, 'get_video': checkbox
110
  }}
111
+ output = _query(inputs, size_gpu)
112
 
113
  # Transcription of the audio
114
  question = output['question']