AkiKagura commited on
Commit
6c2c213
1 Parent(s): 36e81f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -12
app.py CHANGED
@@ -8,7 +8,7 @@ from io import BytesIO
8
  import os
9
  MY_SECRET_TOKEN=os.environ.get('HF_TOKEN_SD')
10
 
11
- from diffusers import StableDiffusionPipeline
12
  from diffusers import StableDiffusionImg2ImgPipeline
13
 
14
  def empty_checker(images, **kwargs): return images, False
@@ -24,11 +24,6 @@ img_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("AkiKagura/mkgen-diffu
24
  img_pipe.safety_checker = empty_checker
25
  img_pipe.to(device)
26
 
27
- # txt2img pipeline
28
- pipe = StableDiffusionPipeline.from_pretrained("AkiKagura/mkgen-diffusion", duse_auth_token=YOUR_TOKEN)
29
- pipe.safety_checker = empty_checker
30
- pipe.to(device)
31
-
32
  source_img = gr.Image(source="upload", type="filepath", label="init_img | 512*512 px")
33
  gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery").style(grid=[1], height="auto")
34
 
@@ -45,12 +40,9 @@ def resize(value,img):
45
  def infer(source_img, prompt, guide, steps, seed, strength):
46
  generator = torch.Generator('cpu').manual_seed(seed)
47
 
48
- if source_image is None:
49
- images_list = pipe([prompt] * 1, guidance_scale=guide, num_inference_steps=steps)
50
- else:
51
- source_image = resize(512, source_img)
52
- source_image.save('source.png')
53
- images_list = img_pipe([prompt] * 1, init_image=source_image, strength=strength, guidance_scale=guide, num_inference_steps=steps)
54
  images = []
55
 
56
  for i, image in enumerate(images_list["images"]):
 
8
  import os
9
  MY_SECRET_TOKEN=os.environ.get('HF_TOKEN_SD')
10
 
11
+ #from diffusers import StableDiffusionPipeline
12
  from diffusers import StableDiffusionImg2ImgPipeline
13
 
14
  def empty_checker(images, **kwargs): return images, False
 
24
  img_pipe.safety_checker = empty_checker
25
  img_pipe.to(device)
26
 
 
 
 
 
 
27
  source_img = gr.Image(source="upload", type="filepath", label="init_img | 512*512 px")
28
  gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery").style(grid=[1], height="auto")
29
 
 
40
  def infer(source_img, prompt, guide, steps, seed, strength):
41
  generator = torch.Generator('cpu').manual_seed(seed)
42
 
43
+ source_image = resize(512, source_img)
44
+ source_image.save('source.png')
45
+ images_list = img_pipe([prompt] * 1, init_image=source_image, strength=strength, guidance_scale=guide, num_inference_steps=steps)
 
 
 
46
  images = []
47
 
48
  for i, image in enumerate(images_list["images"]):