jepz commited on
Commit
ad36db4
·
verified ·
1 Parent(s): 70660bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -9
app.py CHANGED
@@ -1,9 +1,10 @@
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
- from diffusers import DiffusionPipeline
5
  from transformers.utils.hub import move_cache
6
  import torch
 
7
 
8
  move_cache()
9
 
@@ -11,17 +12,28 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
  if torch.cuda.is_available():
13
  torch.cuda.max_memory_allocated(device=device)
14
- pipe = DiffusionPipeline.from_pretrained("Envvi/Inkpunk-Diffusion", torch_dtype=torch.float16, variant="fp16")
 
15
  pipe.enable_xformers_memory_efficient_attention()
16
  pipe = pipe.to(device)
17
  else:
18
- pipe = DiffusionPipeline.from_pretrained("Envvi/Inkpunk-Diffusion")
19
  pipe = pipe.to(device)
20
 
21
  MAX_SEED = np.iinfo(np.int32).max
22
  MAX_IMAGE_SIZE = 1024
23
 
24
- def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
 
 
 
 
 
 
 
 
 
 
25
 
26
  if randomize_seed:
27
  seed = random.randint(0, MAX_SEED)
@@ -29,10 +41,11 @@ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance
29
  generator = torch.Generator().manual_seed(seed)
30
 
31
  image = pipe(
 
32
  prompt = prompt,
33
  negative_prompt = negative_prompt,
34
  guidance_scale = guidance_scale,
35
- num_inference_steps = num_inference_steps,
36
  width = width,
37
  height = height,
38
  generator = generator
@@ -65,6 +78,10 @@ with gr.Blocks(css=css) as demo:
65
  # Text-to-Image Gradio Template
66
  Currently running on {power_device}.
67
  """)
 
 
 
 
68
 
69
  with gr.Row():
70
 
@@ -127,10 +144,10 @@ with gr.Blocks(css=css) as demo:
127
  value=0.0,
128
  )
129
 
130
- num_inference_steps = gr.Slider(
131
- label="Number of inference steps",
132
  minimum=1,
133
- maximum=12,
134
  step=1,
135
  value=2,
136
  )
@@ -142,7 +159,7 @@ with gr.Blocks(css=css) as demo:
142
 
143
  run_button.click(
144
  fn = infer,
145
- inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
146
  outputs = [result]
147
  )
148
 
 
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
+ from diffusers import DiffusionPipeline, StableDiffusionImg2ImgPipeline
5
  from transformers.utils.hub import move_cache
6
  import torch
7
+ from PIL import Image
8
 
9
  move_cache()
10
 
 
12
 
13
  if torch.cuda.is_available():
14
  torch.cuda.max_memory_allocated(device=device)
15
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained("Envvi/Inkpunk-Diffusion", torch_dtype=torch.float16)
16
+ #pipe = DiffusionPipeline.from_pretrained("Envvi/Inkpunk-Diffusion", torch_dtype=torch.float16, variant="fp16")
17
  pipe.enable_xformers_memory_efficient_attention()
18
  pipe = pipe.to(device)
19
  else:
20
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained("Envvi/Inkpunk-Diffusion", torch_dtype=torch.float16)
21
  pipe = pipe.to(device)
22
 
23
  MAX_SEED = np.iinfo(np.int32).max
24
  MAX_IMAGE_SIZE = 1024
25
 
26
+ def generate_image(uploaded_image):
27
+ # Open the uploaded image
28
+ image = Image.open(uploaded_image)
29
+
30
+ # Run the image through the Stable Diffusion pipeline
31
+ with torch.no_grad():
32
+ output = pipe(image, guidance_scale=7.5)["sample"][0]
33
+
34
+ return output
35
+
36
+ def infer(base_img, prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, strength):
37
 
38
  if randomize_seed:
39
  seed = random.randint(0, MAX_SEED)
 
41
  generator = torch.Generator().manual_seed(seed)
42
 
43
  image = pipe(
44
+ image = base_img,
45
  prompt = prompt,
46
  negative_prompt = negative_prompt,
47
  guidance_scale = guidance_scale,
48
+ strength = strength,
49
  width = width,
50
  height = height,
51
  generator = generator
 
78
  # Text-to-Image Gradio Template
79
  Currently running on {power_device}.
80
  """)
81
+
82
+
83
+ with gr.Row():
84
+ base_img = gr.Interface(fn=generate_image, inputs=gr.inputs.Image(type="pil"), outputs=gr.outputs.Image(type="pil"))
85
 
86
  with gr.Row():
87
 
 
144
  value=0.0,
145
  )
146
 
147
+ strength = gr.Slider(
148
+ label="strength",
149
  minimum=1,
150
+ maximum=10,
151
  step=1,
152
  value=2,
153
  )
 
159
 
160
  run_button.click(
161
  fn = infer,
162
+ inputs = [base_img, prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, strength],
163
  outputs = [result]
164
  )
165