BlockDetail commited on
Commit
7021612
1 Parent(s): 5a6766c
Files changed (1) hide show
  1. app.py +15 -8
app.py CHANGED
@@ -7,13 +7,18 @@ from extension import CustomStableDiffusionControlNetPipeline
7
 
8
  negative_prompt = ""
9
  device = torch.device('cuda')
10
- controlnet = ControlNetModel.from_pretrained("BlockDetail/PartialSketchControlNet", torch_dtype=torch.float16).to(device)
11
- pipe = CustomStableDiffusionControlNetPipeline.from_pretrained(
12
- "runwayml/stable-diffusion-v1-5",
13
- controlnet=controlnet, torch_dtype=torch.float16
14
- ).to(device)
15
- pipe.safety_checker = None
16
- pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
 
 
 
 
 
17
  threshold = 250
18
  curr_num_samples = 2
19
 
@@ -48,9 +53,11 @@ with gr.Blocks() as demo:
48
  start_state.append([None, None])
49
  sketch_states = gr.State(start_state)
50
  checkbox_state = gr.State(True)
51
-
 
52
  def sketch(curr_sketch_image, dilation_mask, prompt, seed, num_steps, dilation):
53
  global curr_num_samples
 
54
  generator = torch.Generator(device="cuda:0")
55
  generator.manual_seed(seed)
56
 
 
7
 
8
  negative_prompt = ""
9
  device = torch.device('cuda')
10
+ pipe = None
11
+
12
+ +@spaces.GPU
13
+ def load():
14
+ global pipe
15
+ controlnet = ControlNetModel.from_pretrained("BlockDetail/PartialSketchControlNet", torch_dtype=torch.float16).to(device)
16
+ pipe = CustomStableDiffusionControlNetPipeline.from_pretrained(
17
+ "runwayml/stable-diffusion-v1-5",
18
+ controlnet=controlnet, torch_dtype=torch.float16
19
+ ).to(device)
20
+ pipe.safety_checker = None
21
+ pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
22
  threshold = 250
23
  curr_num_samples = 2
24
 
 
53
  start_state.append([None, None])
54
  sketch_states = gr.State(start_state)
55
  checkbox_state = gr.State(True)
56
+
57
+ +@spaces.GPU
58
  def sketch(curr_sketch_image, dilation_mask, prompt, seed, num_steps, dilation):
59
  global curr_num_samples
60
+ global pipe
61
  generator = torch.Generator(device="cuda:0")
62
  generator.manual_seed(seed)
63