saikub committed on
Commit c46f894
1 Parent(s): f1eee30

Update app.py

Files changed (1)
  1. app.py +37 -10
app.py CHANGED
@@ -6,19 +6,40 @@ import torch

 device = "cuda" if torch.cuda.is_available() else "cpu"

-if torch.cuda.is_available():
-    torch.cuda.max_memory_allocated(device=device)
-    pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
-    pipe.enable_xformers_memory_efficient_attention()
-    pipe = pipe.to(device)
-else:
-    pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
+# List of models
+models = {
+    "sdxl-turbo": "stabilityai/sdxl-turbo",
+    "MistoLine": "TheMistoAI/MistoLine"
+}
+
+# Cache to store loaded pipelines
+pipelines = {}
+
+# Function to load a model
+def load_model(model_name):
+    if model_name in pipelines:
+        return pipelines[model_name]
+
+    if model_name not in models:
+        raise ValueError(f"Model {model_name} is not available.")
+
+    model_path = models[model_name]
+    if torch.cuda.is_available():
+        torch.cuda.max_memory_allocated(device=device)
+        pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
+        pipe.enable_xformers_memory_efficient_attention()
+    else:
+        pipe = DiffusionPipeline.from_pretrained(model_path, use_safetensors=True)
     pipe = pipe.to(device)
+
+    pipelines[model_name] = pipe
+    return pipe

 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024

-def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
+def infer(model_name, prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
+    pipe = load_model(model_name)

     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
@@ -64,6 +85,12 @@ with gr.Blocks(css=css) as demo:
         """)

         with gr.Row():
+            model_name = gr.Dropdown(
+                label="Select Model",
+                choices=list(models.keys()),
+                value="sdxl-turbo",
+                show_label=True
+            )

             prompt = gr.Text(
                 label="Prompt",
@@ -139,8 +166,8 @@ with gr.Blocks(css=css) as demo:

     run_button.click(
         fn = infer,
-        inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
+        inputs = [model_name, prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
         outputs = [result]
     )

-demo.queue().launch()
+demo.queue().launch()
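
The new load_model memoizes the expensive from_pretrained call in a module-level dict: the first request for a model builds the pipeline, later requests reuse it. For comparison, a minimal sketch of the same lazy-load-and-reuse pattern written with functools.lru_cache; this is an alternative formulation, not what the commit ships:

from functools import lru_cache

import torch
from diffusers import DiffusionPipeline

models = {
    "sdxl-turbo": "stabilityai/sdxl-turbo",
    "MistoLine": "TheMistoAI/MistoLine",
}
device = "cuda" if torch.cuda.is_available() else "cpu"

@lru_cache(maxsize=None)
def load_model(model_name):
    # lru_cache keys on model_name, so each pipeline is built at most once
    if model_name not in models:
        raise ValueError(f"Model {model_name} is not available.")
    kwargs = {"torch_dtype": torch.float16, "variant": "fp16"} if device == "cuda" else {}
    pipe = DiffusionPipeline.from_pretrained(models[model_name], use_safetensors=True, **kwargs)
    return pipe.to(device)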
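
One line carried over into load_model deserves a note: torch.cuda.max_memory_allocated(device=device) only reports the peak bytes held by the caching allocator, so calling it and discarding the return value has no effect. A sketch of how the statistic is normally used:

import torch

if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()          # start a fresh measurement window
    # ... run the pipeline forward pass here ...
    peak = torch.cuda.max_memory_allocated()      # peak bytes since the reset
    print(f"peak CUDA memory: {peak / 1e6:.1f} MB")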
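
The body of infer beyond the seed handling sits outside the diff context. In a typical diffusers app the integer seed feeds a torch.Generator that is passed to the pipeline call; that is an assumption here, sketched below:

import random

import numpy as np
import torch

MAX_SEED = np.iinfo(np.int32).max
seed = random.randint(0, MAX_SEED)

# Assumed usage; the actual pipe(...) call is not shown in this diff.
generator = torch.Generator(device="cpu").manual_seed(seed)
# image = pipe(prompt=prompt, generator=generator).images[0]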
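
On the UI side, the change is mechanical: the Dropdown's current value is passed as the first positional input to the click handler, matching the new infer signature. A self-contained sketch of that wiring, with an echo handler standing in for the real pipeline call:

import gradio as gr

models = {"sdxl-turbo": "stabilityai/sdxl-turbo", "MistoLine": "TheMistoAI/MistoLine"}

def infer(model_name, prompt):
    # stand-in for the real handler: just echo what would run
    return f"would run {models[model_name]} on: {prompt!r}"

with gr.Blocks() as demo:
    model_name = gr.Dropdown(label="Select Model", choices=list(models.keys()), value="sdxl-turbo")
    prompt = gr.Text(label="Prompt")
    run_button = gr.Button("Run")
    result = gr.Textbox(label="Result")
    run_button.click(fn=infer, inputs=[model_name, prompt], outputs=[result])

demo.queue().launch()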