zzk1st committed on
Commit
5bd33c2
1 Parent(s): 1e9b2c6

Multiple user API keys

Browse files
Files changed (6) hide show
  1. config.yaml +1 -1
  2. pipeline.py +34 -32
  3. share_btn.py +2 -2
  4. ui_client.py +12 -8
  5. utils.py +1 -2
  6. wavjourney_cli.py +4 -1
config.yaml CHANGED
@@ -18,4 +18,4 @@ Voice-Parser:
18
 
19
  Service-Port: 5000
20
 
21
- OpenAI-Key: ''
 
18
 
19
  Service-Port: 5000
20
 
21
+ OpenAI-Key: ''
pipeline.py CHANGED
@@ -4,7 +4,6 @@ from string import Template
4
  import openai
5
  import re
6
  import glob
7
- from utils import get_key
8
  import pickle
9
  import time
10
  import json5
@@ -26,28 +25,33 @@ if USE_OPENAI_CACHE:
26
  with open(cache_file, 'rb') as file:
27
  openai_cache.append(pickle.load(file))
28
 
29
- openai.api_key = get_key()
30
 
31
- def chat_with_gpt(prompt):
32
  if USE_OPENAI_CACHE:
33
  filtered_object = list(filter(lambda x: x['prompt'] == prompt, openai_cache))
34
  if len(filtered_object) > 0:
35
  response = filtered_object[0]['response']
36
  return response
37
- chat = openai.ChatCompletion.create(
38
- # model="gpt-3.5-turbo",
39
- model="gpt-4",
40
- messages=[
41
- {
42
- "role": "system",
43
- "content": "You are a helpful assistant."
44
- },
45
- {
46
- "role": "user",
47
- "content": prompt
48
- }
49
- ]
50
- )
 
 
 
 
 
 
51
  if USE_OPENAI_CACHE:
52
  cache_obj = {
53
  'prompt': prompt,
@@ -120,10 +124,10 @@ def init_session(session_id=''):
120
  return session_id
121
 
122
  @retry(stop_max_attempt_number=3)
123
- def input_text_to_json_script_with_retry(complete_prompt_path):
124
  print(" trying ...")
125
  complete_prompt = get_file_content(complete_prompt_path)
126
- json_response = try_extract_content_from_quotes(chat_with_gpt(complete_prompt))
127
  json_data = json5.loads(json_response)
128
 
129
  try:
@@ -138,22 +142,20 @@ def input_text_to_json_script_with_retry(complete_prompt_path):
138
  return json_response
139
 
140
  # Step 1: input_text to json
141
- def input_text_to_json_script(input_text, output_path):
142
  print('Step 1: Writing audio script with LLM ...')
143
  input_text = maybe_get_content_from_file(input_text)
144
  text_to_audio_script_prompt = get_file_content('prompts/text_to_json.prompt')
145
  prompt = f'{text_to_audio_script_prompt}\n\nInput text: {input_text}\n\nScript:\n'
146
  complete_prompt_path = output_path / 'complete_input_text_to_audio_script.prompt'
147
  write_to_file(complete_prompt_path, prompt)
148
- audio_script_response = input_text_to_json_script_with_retry(complete_prompt_path)
149
  generated_audio_script_filename = output_path / 'audio_script.json'
150
  write_to_file(generated_audio_script_filename, audio_script_response)
151
  return audio_script_response
152
 
153
  # Step 2: json to char-voice map
154
- def json_script_to_char_voice_map(json_script, voices, output_path):
155
- def create_complete_char_voice_map(char_voice_map):
156
- return
157
  print('Step 2: Parsing character voice with LLM...')
158
  json_script_content = maybe_get_content_from_file(json_script)
159
  prompt = get_file_content('prompts/audio_script_to_character_voice_map.prompt')
@@ -161,7 +163,7 @@ def json_script_to_char_voice_map(json_script, voices, output_path):
161
  prompt = Template(prompt).substitute(voice_and_desc=presets_str)
162
  prompt = f"{prompt}\n\nAudio script:\n'''\n{json_script_content}\n'''\n\noutput:\n"
163
  write_to_file(output_path / 'complete_audio_script_to_char_voice_map.prompt', prompt)
164
- char_voice_map_response = try_extract_content_from_quotes(chat_with_gpt(prompt))
165
  char_voice_map = json5.loads(char_voice_map_response)
166
  # enrich char_voice_map with voice preset metadata
167
  complete_char_voice_map = {c: voices[char_voice_map[c]] for c in char_voice_map}
@@ -188,19 +190,19 @@ def audio_code_gen_to_result(audio_gen_code_path):
188
  os.system(f'python {audio_gen_code_filename}')
189
 
190
  # Function call used by Gradio: input_text to json
191
- def generate_json_file(session_id, input_text):
192
  output_path = utils.get_session_path(session_id)
193
  # Step 1
194
- return input_text_to_json_script(input_text, output_path)
195
 
196
  # Function call used by Gradio: json to result wav
197
- def generate_audio(session_id, json_script):
198
  output_path = utils.get_session_path(session_id)
199
  output_audio_path = utils.get_session_audio_path(session_id)
200
  voices = voice_presets.get_merged_voice_presets(session_id)
201
 
202
  # Step 2
203
- char_voice_map = json_script_to_char_voice_map(json_script, voices, output_path)
204
  # Step 3
205
  json_script_filename = output_path / 'audio_script.json'
206
  char_voice_map_filename = output_path / 'character_voice_map.json'
@@ -214,6 +216,6 @@ def generate_audio(session_id, json_script):
214
  return result_wav_filename, char_voice_map
215
 
216
  # Convenient function call used by wavjourney_cli
217
- def full_steps(session_id, input_text):
218
- json_script = generate_json_file(session_id, input_text)
219
- return generate_audio(session_id, json_script)
 
4
  import openai
5
  import re
6
  import glob
 
7
  import pickle
8
  import time
9
  import json5
 
25
  with open(cache_file, 'rb') as file:
26
  openai_cache.append(pickle.load(file))
27
 
 
28
 
29
+ def chat_with_gpt(prompt, api_key):
30
  if USE_OPENAI_CACHE:
31
  filtered_object = list(filter(lambda x: x['prompt'] == prompt, openai_cache))
32
  if len(filtered_object) > 0:
33
  response = filtered_object[0]['response']
34
  return response
35
+
36
+ try:
37
+ openai.api_key = api_key
38
+ chat = openai.ChatCompletion.create(
39
+ # model="gpt-3.5-turbo",
40
+ model="gpt-4",
41
+ messages=[
42
+ {
43
+ "role": "system",
44
+ "content": "You are a helpful assistant."
45
+ },
46
+ {
47
+ "role": "user",
48
+ "content": prompt
49
+ }
50
+ ]
51
+ )
52
+ finally:
53
+ openai.api_key = ''
54
+
55
  if USE_OPENAI_CACHE:
56
  cache_obj = {
57
  'prompt': prompt,
 
124
  return session_id
125
 
126
  @retry(stop_max_attempt_number=3)
127
+ def input_text_to_json_script_with_retry(complete_prompt_path, api_key):
128
  print(" trying ...")
129
  complete_prompt = get_file_content(complete_prompt_path)
130
+ json_response = try_extract_content_from_quotes(chat_with_gpt(complete_prompt, api_key))
131
  json_data = json5.loads(json_response)
132
 
133
  try:
 
142
  return json_response
143
 
144
  # Step 1: input_text to json
145
+ def input_text_to_json_script(input_text, output_path, api_key):
146
  print('Step 1: Writing audio script with LLM ...')
147
  input_text = maybe_get_content_from_file(input_text)
148
  text_to_audio_script_prompt = get_file_content('prompts/text_to_json.prompt')
149
  prompt = f'{text_to_audio_script_prompt}\n\nInput text: {input_text}\n\nScript:\n'
150
  complete_prompt_path = output_path / 'complete_input_text_to_audio_script.prompt'
151
  write_to_file(complete_prompt_path, prompt)
152
+ audio_script_response = input_text_to_json_script_with_retry(complete_prompt_path, api_key)
153
  generated_audio_script_filename = output_path / 'audio_script.json'
154
  write_to_file(generated_audio_script_filename, audio_script_response)
155
  return audio_script_response
156
 
157
  # Step 2: json to char-voice map
158
+ def json_script_to_char_voice_map(json_script, voices, output_path, api_key):
 
 
159
  print('Step 2: Parsing character voice with LLM...')
160
  json_script_content = maybe_get_content_from_file(json_script)
161
  prompt = get_file_content('prompts/audio_script_to_character_voice_map.prompt')
 
163
  prompt = Template(prompt).substitute(voice_and_desc=presets_str)
164
  prompt = f"{prompt}\n\nAudio script:\n'''\n{json_script_content}\n'''\n\noutput:\n"
165
  write_to_file(output_path / 'complete_audio_script_to_char_voice_map.prompt', prompt)
166
+ char_voice_map_response = try_extract_content_from_quotes(chat_with_gpt(prompt, api_key))
167
  char_voice_map = json5.loads(char_voice_map_response)
168
  # enrich char_voice_map with voice preset metadata
169
  complete_char_voice_map = {c: voices[char_voice_map[c]] for c in char_voice_map}
 
190
  os.system(f'python {audio_gen_code_filename}')
191
 
192
  # Function call used by Gradio: input_text to json
193
+ def generate_json_file(session_id, input_text, api_key):
194
  output_path = utils.get_session_path(session_id)
195
  # Step 1
196
+ return input_text_to_json_script(input_text, output_path, api_key)
197
 
198
  # Function call used by Gradio: json to result wav
199
+ def generate_audio(session_id, json_script, api_key):
200
  output_path = utils.get_session_path(session_id)
201
  output_audio_path = utils.get_session_audio_path(session_id)
202
  voices = voice_presets.get_merged_voice_presets(session_id)
203
 
204
  # Step 2
205
+ char_voice_map = json_script_to_char_voice_map(json_script, voices, output_path, api_key)
206
  # Step 3
207
  json_script_filename = output_path / 'audio_script.json'
208
  char_voice_map_filename = output_path / 'character_voice_map.json'
 
216
  return result_wav_filename, char_voice_map
217
 
218
  # Convenient function call used by wavjourney_cli
219
+ def full_steps(session_id, input_text, api_key):
220
+ json_script = generate_json_file(session_id, input_text, api_key)
221
+ return generate_audio(session_id, json_script, api_key)
share_btn.py CHANGED
@@ -26,7 +26,7 @@ share_js = """async () => {
26
  const res = await fetch(videoEl.src);
27
  const blob = await res.blob();
28
  const videoId = Date.now() % 200;
29
- const fileName = `sd-perception-${{videoId}}.mp4`;
30
  return new File([blob], fileName, { type: 'video/mp4' });
31
  }
32
 
@@ -40,7 +40,7 @@ share_js = """async () => {
40
  });
41
  }
42
  const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app');
43
- const inputPromptEl = gradioEl.querySelector('#prompt-in input').value;
44
  const outputVideoEl = gradioEl.querySelector('#output-video video');
45
 
46
  let titleTxt = `WavJourney: ${inputPromptEl}`;
 
26
  const res = await fetch(videoEl.src);
27
  const blob = await res.blob();
28
  const videoId = Date.now() % 200;
29
+ const fileName = `sd-perception-${videoId}.mp4`;
30
  return new File([blob], fileName, { type: 'video/mp4' });
31
  }
32
 
 
40
  });
41
  }
42
  const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app');
43
+ const inputPromptEl = gradioEl.querySelector('#prompt-in textarea').value;
44
  const outputVideoEl = gradioEl.querySelector('#output-video video');
45
 
46
  let titleTxt = `WavJourney: ${inputPromptEl}`;
ui_client.py CHANGED
@@ -1,7 +1,7 @@
1
  import shutil
2
  import json5
 
3
 
4
- import openai
5
  import gradio as gr
6
  from tabulate import tabulate
7
 
@@ -44,11 +44,13 @@ def convert_char_voice_map_to_md(char_voice_map):
44
  def generate_script_fn(instruction, _state: gr.State):
45
  try:
46
  session_id = _state['session_id']
47
- json_script = generate_json_file(session_id, instruction)
 
48
  table_text = convert_json_to_md(json_script)
49
  except Exception as e:
50
  gr.Warning(str(e))
51
  print(f"Generating script error: {str(e)}")
 
52
  return [
53
  None,
54
  _state,
@@ -89,9 +91,8 @@ def generate_audio_fn(state):
89
  ]
90
  except Exception as e:
91
  print(f"Generation audio error: {str(e)}")
 
92
  gr.Warning(str(e))
93
- # For debugging, uncomment the line below
94
- #raise e
95
 
96
  return [
97
  None,
@@ -172,8 +173,8 @@ def get_system_voice_presets():
172
  return data
173
 
174
 
175
- def set_openai_key(key):
176
- openai.api_key = key
177
  return key
178
 
179
 
@@ -191,7 +192,10 @@ def add_voice_preset(vp_id, vp_desc, file, ui_state, added_voice_preset):
191
  add_session_voice_preset(vp_id, vp_desc, file_path, session_id)
192
  added_voice_preset['count'] = count + 1
193
  except Exception as exception:
 
 
194
  gr.Warning(str(exception))
 
195
  # After added
196
  dataframe = get_voice_preset_to_list(ui_state)
197
  df_visible = gr.Dataframe.update(visible=True)
@@ -379,7 +383,7 @@ with gr.Blocks(css=css) as interface:
379
 
380
  system_voice_presets = get_system_voice_presets()
381
  # State
382
- ui_state = gr.State(value={'session_id': pipeline.init_session()})
383
  selected_voice_presets = gr.State(value={'selected_voice_preset': None})
384
  added_voice_preset_state = gr.State(value={'added_file': None, 'count': 0})
385
  # UI Component
@@ -461,7 +465,7 @@ with gr.Blocks(css=css) as interface:
461
  )
462
 
463
  # events
464
- key_text_input.change(fn=set_openai_key, inputs=[key_text_input], outputs=[key_text_input])
465
  text_input.change(fn=textbox_listener, inputs=[text_input], outputs=[generate_script_btn])
466
  generate_audio_btn.click(
467
  fn=generate_audio_fn,
 
1
  import shutil
2
  import json5
3
+ import traceback
4
 
 
5
  import gradio as gr
6
  from tabulate import tabulate
7
 
 
44
  def generate_script_fn(instruction, _state: gr.State):
45
  try:
46
  session_id = _state['session_id']
47
+ api_key = _state['api_key']
48
+ json_script = generate_json_file(session_id, instruction, api_key)
49
  table_text = convert_json_to_md(json_script)
50
  except Exception as e:
51
  gr.Warning(str(e))
52
  print(f"Generating script error: {str(e)}")
53
+ traceback.print_exc()
54
  return [
55
  None,
56
  _state,
 
91
  ]
92
  except Exception as e:
93
  print(f"Generation audio error: {str(e)}")
94
+ traceback.print_exc()
95
  gr.Warning(str(e))
 
 
96
 
97
  return [
98
  None,
 
173
  return data
174
 
175
 
176
+ def set_openai_key(key, _state):
177
+ _state['api_key'] = key
178
  return key
179
 
180
 
 
192
  add_session_voice_preset(vp_id, vp_desc, file_path, session_id)
193
  added_voice_preset['count'] = count + 1
194
  except Exception as exception:
195
+ print(exception)
196
+ traceback.print_exc()
197
  gr.Warning(str(exception))
198
+
199
  # After added
200
  dataframe = get_voice_preset_to_list(ui_state)
201
  df_visible = gr.Dataframe.update(visible=True)
 
383
 
384
  system_voice_presets = get_system_voice_presets()
385
  # State
386
+ ui_state = gr.State(value={'session_id': pipeline.init_session(), 'api_key': ''})
387
  selected_voice_presets = gr.State(value={'selected_voice_preset': None})
388
  added_voice_preset_state = gr.State(value={'added_file': None, 'count': 0})
389
  # UI Component
 
465
  )
466
 
467
  # events
468
+ key_text_input.change(fn=set_openai_key, inputs=[key_text_input, ui_state], outputs=[key_text_input])
469
  text_input.change(fn=textbox_listener, inputs=[text_input], outputs=[generate_script_btn])
470
  generate_audio_btn.click(
471
  fn=generate_audio_fn,
utils.py CHANGED
@@ -62,6 +62,5 @@ def fade(audio_data, fade_duration=2, sr=32000):
62
  def get_key(config='config.yaml'):
63
  with open('config.yaml', 'r') as file:
64
  config = yaml.safe_load(file)
65
- openai_key = config['OpenAI-Key']
66
- return openai_key
67
 
 
62
  def get_key(config='config.yaml'):
63
  with open('config.yaml', 'r') as file:
64
  config = yaml.safe_load(file)
65
+ return config['OpenAI-Key'] if 'OpenAI-Key' in config else None
 
66
 
wavjourney_cli.py CHANGED
@@ -1,12 +1,14 @@
1
  import time
2
  import argparse
3
 
 
4
  import pipeline
5
 
6
  parser = argparse.ArgumentParser()
7
  parser.add_argument('-f', '--full', action='store_true', help='Go through the full process')
8
  parser.add_argument('--input-text', type=str, default='', help='input text or text file')
9
  parser.add_argument('--session-id', type=str, default='', help='session id, if set to empty, system will allocate an id')
 
10
  args = parser.parse_args()
11
 
12
  if args.full:
@@ -14,10 +16,11 @@ if args.full:
14
 
15
  start_time = time.time()
16
  session_id = pipeline.init_session(args.session_id)
 
17
 
18
  print(f"Session {session_id} is created.")
19
 
20
- pipeline.full_steps(session_id, input_text)
21
  end_time = time.time()
22
 
23
  print(f"WavJourney took {end_time - start_time:.2f} seconds to complete.")
 
1
  import time
2
  import argparse
3
 
4
+ import utils
5
  import pipeline
6
 
7
  parser = argparse.ArgumentParser()
8
  parser.add_argument('-f', '--full', action='store_true', help='Go through the full process')
9
  parser.add_argument('--input-text', type=str, default='', help='input text or text file')
10
  parser.add_argument('--session-id', type=str, default='', help='session id, if set to empty, system will allocate an id')
11
+ parser.add_argument('--api-key', type=str, default='', help='api key used for GPT-4')
12
  args = parser.parse_args()
13
 
14
  if args.full:
 
16
 
17
  start_time = time.time()
18
  session_id = pipeline.init_session(args.session_id)
19
+ api_key = args.api_key if args.api_key != '' else utils.get_key()
20
 
21
  print(f"Session {session_id} is created.")
22
 
23
+ pipeline.full_steps(session_id, input_text, api_key)
24
  end_time = time.time()
25
 
26
  print(f"WavJourney took {end_time - start_time:.2f} seconds to complete.")