AppleBotzz commited on
Commit
84c58dc
1 Parent(s): fb40a32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -14,6 +14,8 @@ claude_models = ["claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3
14
  # List of available OpenAI models
15
  openai_models = ["gpt-4", "gpt-4-32k", "gpt-3.5-turbo", "gpt-4-0125-preview", "gpt-4-turbo-preview", "gpt-4-1106-preview", "gpt-4-0613"]
16
 
 
 
17
  both_models = claude_models + openai_models
18
 
19
  def generate_response(endpoint, api_key, model, user_prompt):
@@ -113,15 +115,15 @@ def extract_json(generated_output):
113
  print(e)
114
  return None, None
115
 
116
- def generate_second_response(endpoint, api_key, model, generated_output):
117
  if endpoint in default_urls:
118
  #check api keys as normal
119
  if api_key.startswith("sk-ant-"):
120
  client = Anthropic(api_key=api_key, base_url=endpoint)
121
- system_prompt_path = __file__.replace("app.py", "diffusion.txt")
122
  elif api_key.startswith("sk-"):
123
  client = OpenAI(api_key=api_key, base_url=endpoint)
124
- system_prompt_path = __file__.replace("app.py", "diffusion.txt")
125
  else:
126
  print("Invalid API key")
127
  return "Invalid API key", "Invalid API key", None
@@ -129,11 +131,11 @@ def generate_second_response(endpoint, api_key, model, generated_output):
129
  if model in claude_models:
130
  # Set the Anthropic API key
131
  client = Anthropic(api_key=api_key, base_url=endpoint)
132
- system_prompt_path = __file__.replace("app.py", "diffusion.txt")
133
  else:
134
  # Set the OpenAI API key
135
  client = OpenAI(api_key=api_key, base_url=endpoint)
136
- system_prompt_path = __file__.replace("app.py", "diffusion.txt")
137
 
138
  # Read the system prompt from a text file
139
  with open(system_prompt_path, "r") as file:
@@ -192,6 +194,7 @@ with gr.Blocks() as demo:
192
 
193
  with gr.Row():
194
  with gr.Column():
 
195
  generate_button_2 = gr.Button("Generate SDXL Prompt")
196
 
197
  with gr.Column():
@@ -208,6 +211,6 @@ with gr.Blocks() as demo:
208
  api_key.change(update_models, inputs=api_key, outputs=[model_dropdown, endpoint])
209
 
210
  generate_button.click(generate_response, inputs=[endpoint, api_key, model_dropdown, user_prompt], outputs=[generated_output, json_output, json_download])
211
- generate_button_2.click(generate_second_response, inputs=[endpoint, api_key, model_dropdown, generated_output], outputs=generated_output_2)
212
 
213
  demo.launch()
 
14
  # List of available OpenAI models
15
  openai_models = ["gpt-4", "gpt-4-32k", "gpt-3.5-turbo", "gpt-4-0125-preview", "gpt-4-turbo-preview", "gpt-4-1106-preview", "gpt-4-0613"]
16
 
17
+ image_prompter = ["SDXL", "midjourney"]
18
+
19
  both_models = claude_models + openai_models
20
 
21
  def generate_response(endpoint, api_key, model, user_prompt):
 
115
  print(e)
116
  return None, None
117
 
118
+ def generate_second_response(endpoint, api_key, model, generated_output, image_model):
119
  if endpoint in default_urls:
120
  #check api keys as normal
121
  if api_key.startswith("sk-ant-"):
122
  client = Anthropic(api_key=api_key, base_url=endpoint)
123
+ system_prompt_path = __file__.replace("app.py", f"{image_model}.txt")
124
  elif api_key.startswith("sk-"):
125
  client = OpenAI(api_key=api_key, base_url=endpoint)
126
+ system_prompt_path = __file__.replace("app.py", f"{image_model}.txt")
127
  else:
128
  print("Invalid API key")
129
  return "Invalid API key", "Invalid API key", None
 
131
  if model in claude_models:
132
  # Set the Anthropic API key
133
  client = Anthropic(api_key=api_key, base_url=endpoint)
134
+ system_prompt_path = __file__.replace("app.py", f"{image_model}.txt")
135
  else:
136
  # Set the OpenAI API key
137
  client = OpenAI(api_key=api_key, base_url=endpoint)
138
+ system_prompt_path = __file__.replace("app.py", f"{image_model}.txt")
139
 
140
  # Read the system prompt from a text file
141
  with open(system_prompt_path, "r") as file:
 
194
 
195
  with gr.Row():
196
  with gr.Column():
197
+ image_model = gr.Dropdown(choices=image_prompter, label="Image Model to prompt for", value="SDXL")
198
  generate_button_2 = gr.Button("Generate SDXL Prompt")
199
 
200
  with gr.Column():
 
211
  api_key.change(update_models, inputs=api_key, outputs=[model_dropdown, endpoint])
212
 
213
  generate_button.click(generate_response, inputs=[endpoint, api_key, model_dropdown, user_prompt], outputs=[generated_output, json_output, json_download])
214
+ generate_button_2.click(generate_second_response, inputs=[endpoint, api_key, model_dropdown, generated_output, image_model], outputs=generated_output_2)
215
 
216
  demo.launch()