soiz committed on
Commit
e940c08
•
1 Parent(s): 8c4ccaa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -6
app.py CHANGED
@@ -27,7 +27,7 @@ def load_model(model_name):
27
  print(f"Error loading model {model_name}: {error}")
28
  models_load[model_name] = gr.Interface(lambda txt: None, ['text'], ['image'])
29
 
30
- def gen_fn(model_str, prompt, negative_prompt=None, noise=None, cfg_scale=None):
31
  if model_str not in models_load:
32
  load_model(model_str) # γƒ’γƒ‡γƒ«γŒγƒ­γƒΌγƒ‰γ•γ‚Œγ¦γ„γͺγ„ε ΄εˆγ―γƒ­γƒΌγƒ‰γ™γ‚‹
33
 
@@ -38,11 +38,16 @@ def gen_fn(model_str, prompt, negative_prompt=None, noise=None, cfg_scale=None):
38
  try:
39
  if negative_prompt:
40
  full_prompt += f' -{negative_prompt}'
41
- # Adjust function call based on whether cfg_scale is provided
 
 
42
  if cfg_scale is not None:
43
- result = models_load[model_str](full_prompt, cfg_scale=cfg_scale)
44
- else:
45
- result = models_load[model_str](full_prompt)
 
 
 
46
  # Check if result is an image or a file path
47
  if isinstance(result, str): # Assuming result might be a file path
48
  if os.path.exists(result):
@@ -69,12 +74,19 @@ def home():
69
  negative_prompt = request.args.get('Nprompt', None)
70
  noise = request.args.get('noise', None)
71
  cfg_scale = request.args.get('cfg_scale', None)
 
72
 
73
  try:
74
  if cfg_scale is not None:
75
  cfg_scale = float(cfg_scale)
76
  except ValueError:
77
  return 'Invalid "cfg_scale" parameter. It should be a number.', 400
 
 
 
 
 
 
78
 
79
  if not model:
80
  return 'Please provide a "model" query parameter in the URL.', 400
@@ -83,7 +95,7 @@ def home():
83
  return 'Please provide a "prompt" query parameter in the URL.', 400
84
 
85
  # Generate the image
86
- image = gen_fn(model, prompt, negative_prompt, noise, cfg_scale)
87
  if isinstance(image, Image.Image): # Ensure the result is a PIL image
88
  # Save image to BytesIO object
89
  img_io = BytesIO()
 
27
  print(f"Error loading model {model_name}: {error}")
28
  models_load[model_name] = gr.Interface(lambda txt: None, ['text'], ['image'])
29
 
30
+ def gen_fn(model_str, prompt, negative_prompt=None, noise=None, cfg_scale=None, num_inference_steps=None):
31
  if model_str not in models_load:
32
  load_model(model_str) # γƒ’γƒ‡γƒ«γŒγƒ­γƒΌγƒ‰γ•γ‚Œγ¦γ„γͺγ„ε ΄εˆγ―γƒ­γƒΌγƒ‰γ™γ‚‹
33
 
 
38
  try:
39
  if negative_prompt:
40
  full_prompt += f' -{negative_prompt}'
41
+
42
+ # Construct the function call parameters dynamically
43
+ call_params = {'text': full_prompt}
44
  if cfg_scale is not None:
45
+ call_params['cfg_scale'] = cfg_scale
46
+ if num_inference_steps is not None:
47
+ call_params['num_inference_steps'] = num_inference_steps
48
+
49
+ result = models_load[model_str](**call_params)
50
+
51
  # Check if result is an image or a file path
52
  if isinstance(result, str): # Assuming result might be a file path
53
  if os.path.exists(result):
 
74
  negative_prompt = request.args.get('Nprompt', None)
75
  noise = request.args.get('noise', None)
76
  cfg_scale = request.args.get('cfg_scale', None)
77
+ num_inference_steps = request.args.get('steps', None)
78
 
79
  try:
80
  if cfg_scale is not None:
81
  cfg_scale = float(cfg_scale)
82
  except ValueError:
83
  return 'Invalid "cfg_scale" parameter. It should be a number.', 400
84
+
85
+ try:
86
+ if num_inference_steps is not None:
87
+ num_inference_steps = int(num_inference_steps)
88
+ except ValueError:
89
+ return 'Invalid "steps" parameter. It should be an integer.', 400
90
 
91
  if not model:
92
  return 'Please provide a "model" query parameter in the URL.', 400
 
95
  return 'Please provide a "prompt" query parameter in the URL.', 400
96
 
97
  # Generate the image
98
+ image = gen_fn(model, prompt, negative_prompt, noise, cfg_scale, num_inference_steps)
99
  if isinstance(image, Image.Image): # Ensure the result is a PIL image
100
  # Save image to BytesIO object
101
  img_io = BytesIO()