Spaces:
Running
on
Zero
Running
on
Zero
BlockDetail
commited on
Commit
•
3214d99
1
Parent(s):
162c342
env
Browse files
app.py
CHANGED
@@ -4,6 +4,7 @@ import torch
|
|
4 |
import numpy as np
|
5 |
from PIL import Image, ImageFilter
|
6 |
from extension import CustomStableDiffusionControlNetPipeline
|
|
|
7 |
|
8 |
negative_prompt = ""
|
9 |
device = torch.device('cuda')
|
@@ -11,6 +12,7 @@ pipe = None
|
|
11 |
|
12 |
print(gr.__version__)
|
13 |
|
|
|
14 |
def load():
|
15 |
global pipe
|
16 |
controlnet = ControlNetModel.from_pretrained("BlockDetail/PartialSketchControlNet", torch_dtype=torch.float16).to(device)
|
@@ -54,6 +56,7 @@ with gr.Blocks() as demo:
|
|
54 |
sketch_states = gr.State(start_state)
|
55 |
checkbox_state = gr.State(True)
|
56 |
|
|
|
57 |
def sketch(curr_sketch_image, dilation_mask, prompt, seed, num_steps, dilation):
|
58 |
global curr_num_samples
|
59 |
global pipe
|
@@ -159,5 +162,5 @@ with gr.Blocks() as demo:
|
|
159 |
stroke_type[0].change(change_color, [stroke_type[0]], canvas)
|
160 |
num_samples[0].change(change_num_samples, [num_samples[0]], None)
|
161 |
|
162 |
-
|
163 |
demo.launch(share = True, debug = True)
|
|
|
4 |
import numpy as np
|
5 |
from PIL import Image, ImageFilter
|
6 |
from extension import CustomStableDiffusionControlNetPipeline
|
7 |
+
import spaces
|
8 |
|
9 |
negative_prompt = ""
|
10 |
device = torch.device('cuda')
|
|
|
12 |
|
13 |
print(gr.__version__)
|
14 |
|
15 |
+
@spaces.GPU
|
16 |
def load():
|
17 |
global pipe
|
18 |
controlnet = ControlNetModel.from_pretrained("BlockDetail/PartialSketchControlNet", torch_dtype=torch.float16).to(device)
|
|
|
56 |
sketch_states = gr.State(start_state)
|
57 |
checkbox_state = gr.State(True)
|
58 |
|
59 |
+
@spaces.GPU
|
60 |
def sketch(curr_sketch_image, dilation_mask, prompt, seed, num_steps, dilation):
|
61 |
global curr_num_samples
|
62 |
global pipe
|
|
|
162 |
stroke_type[0].change(change_color, [stroke_type[0]], canvas)
|
163 |
num_samples[0].change(change_num_samples, [num_samples[0]], None)
|
164 |
|
165 |
+
load()
|
166 |
demo.launch(share = True, debug = True)
|