abi445 committed
Commit d5677ef · 1 Parent(s): ecedd22

Add application file

Files changed (2)
  1. README.md +1 -1
  2. app.py +3 -14
README.md CHANGED
@@ -2,7 +2,7 @@
 title: IllusionDiffusion
 emoji: 🔥
 colorFrom: green
-colorTo: pink
+colorTo: red
 sdk: gradio
 sdk_version: 4.36.1
 app_file: app.py
app.py CHANGED
@@ -27,16 +27,14 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
 BASE_MODEL = "SG161222/Realistic_Vision_V5.1_noVAE"
 
-# Initialize both pipelines
 vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float32)
 controlnet = ControlNetModel.from_pretrained("monster-labs/control_v1p_sd15_qrcode_monster", torch_dtype=torch.float32)
 
-# Initialize the safety checker conditionally
 SAFETY_CHECKER_ENABLED = os.environ.get("SAFETY_CHECKER", "0") == "1"
 safety_checker = None
 feature_extractor = None
 if SAFETY_CHECKER_ENABLED:
-    safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to("cpu")
+    safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
     feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
 main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
@@ -46,11 +44,10 @@ main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
     safety_checker=safety_checker,
     feature_extractor=feature_extractor,
     torch_dtype=torch.float32,
-).to("cpu")
+)
 
 image_pipe = StableDiffusionControlNetImg2ImgPipeline(**main_pipe.components)
 
-# Sampler map
 SAMPLER_MAP = {
     "DPM++ Karras SDE": lambda config: DPMSolverMultistepScheduler.from_config(config, use_karras=True, algorithm_type="sde-dpmsolver++"),
     "Euler": lambda config: EulerDiscreteScheduler.from_config(config),
@@ -58,18 +55,13 @@ SAMPLER_MAP = {
 
 def center_crop_resize(img, output_size=(512, 512)):
     width, height = img.size
-
-    # Calculate dimensions to crop to the center
     new_dimension = min(width, height)
     left = (width - new_dimension)/2
     top = (height - new_dimension)/2
     right = (width + new_dimension)/2
     bottom = (height + new_dimension)/2
-
-    # Crop and resize
     img = img.crop((left, top, right, bottom))
     img = img.resize(output_size)
-
     return img
 
 def common_upscale(samples, width, height, upscale_method, crop=False):
@@ -87,7 +79,6 @@ def common_upscale(samples, width, height, upscale_method, crop=False):
         s = samples[:,:,y:old_height-y,x:old_width-x]
     else:
         s = samples
-
     return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
 
 def upscale(samples, upscale_method, scale_by):
@@ -111,7 +102,6 @@ def convert_to_base64(pil_image):
     image.save(temp_file.name)
     return temp_file.name
 
-# Inference function
 def inference(
     control_image: Image.Image,
     prompt: str,
@@ -136,7 +126,7 @@ def inference(
 
     main_pipe.scheduler = SAMPLER_MAP[sampler](main_pipe.scheduler.config)
     my_seed = random.randint(0, 2**32 - 1) if seed == -1 else seed
-    generator = torch.Generator(device="cpu").manual_seed(my_seed)
+    generator = torch.Generator().manual_seed(my_seed)
 
     out = main_pipe(
         prompt=prompt,
@@ -169,7 +159,6 @@ def inference(
     end_time_formatted = time.strftime("%H:%M:%S", end_time_struct)
     print(f"Inference ended at {end_time_formatted}, taking {end_time-start_time}s")
 
-    # Save image + metadata
     user_history.save_image(
         label=prompt,
         image=out_image["images"][0],
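
For reference, a small standalone sketch of two pieces that appear in the app.py diff above: the center_crop_resize helper and the seed/generator pattern used in inference(). The synthetic Image.new input and the __main__ driver are illustrative assumptions, not part of app.py.

```python
# Standalone sketch, not the full Space: mirrors center_crop_resize() and the
# seeding pattern from inference() in app.py. Requires only Pillow and torch.
import random

import torch
from PIL import Image


def center_crop_resize(img: Image.Image, output_size=(512, 512)) -> Image.Image:
    """Crop the largest centered square from img, then resize to output_size."""
    width, height = img.size
    new_dimension = min(width, height)
    left = (width - new_dimension) / 2
    top = (height - new_dimension) / 2
    right = (width + new_dimension) / 2
    bottom = (height + new_dimension) / 2
    img = img.crop((left, top, right, bottom))
    return img.resize(output_size)


if __name__ == "__main__":
    # A synthetic non-square image stands in for the user-supplied control image.
    control = Image.new("RGB", (768, 512), color="white")
    print(center_crop_resize(control).size)  # (512, 512)

    # seed == -1 means "pick a random seed", as in inference().
    seed = -1
    my_seed = random.randint(0, 2**32 - 1) if seed == -1 else seed
    generator = torch.Generator().manual_seed(my_seed)
    print(my_seed, torch.randn(2, generator=generator))
```

Calling torch.Generator() with no device argument yields a CPU generator, so for a given my_seed the noise it produces is deterministic.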