amos1088 committed on
Commit
52d3f89
·
1 Parent(s): 6c3f566

test gradio

Browse files
Files changed (1) hide show
  1. app.py +12 -39
app.py CHANGED
@@ -1,13 +1,6 @@
1
  import gradio as gr
2
  import torch
3
- from diffusers import (
4
- StableDiffusion3Pipeline, # For SD3 models like Stable Diffusion 3.5
5
- ControlNetModel,
6
- SD3Transformer2DModel, # Replacing UNet with SD3 transformer
7
- AutoencoderKL,
8
- UniPCMultistepScheduler,
9
- )
10
- from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
11
  from huggingface_hub import login
12
  import os
13
 
@@ -17,41 +10,21 @@ login(token=token)
17
 
18
  # Model IDs for the base Stable Diffusion model and ControlNet variant
19
  model_id = "stabilityai/stable-diffusion-3.5-large-turbo"
20
- controlnet_id = "lllyasviel/control_v11p_sd15_inpaint"
21
-
22
- # Load each model component required by the pipeline
23
- controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float16)
24
- transformer = SD3Transformer2DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.float16)
25
- vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float16)
26
- feature_extractor = CLIPFeatureExtractor.from_pretrained(model_id)
27
- text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
28
- tokenizer = CLIPTokenizer.from_pretrained(model_id)
29
-
30
- # Initialize the pipeline with all components
31
- pipeline = StableDiffusion3Pipeline(
32
- transformer=transformer, # Using SD3 transformer
33
- vae=vae,
34
- text_encoder=text_encoder,
35
- tokenizer=tokenizer,
36
- controlnet=controlnet,
37
- scheduler=UniPCMultistepScheduler.from_config({"name": "UniPCMultistepScheduler"}),
38
- feature_extractor=feature_extractor,
39
- torch_dtype=torch.float16,
40
- )
41
-
42
- # Set device for pipeline
43
- pipeline = pipeline.to("cuda") if torch.cuda.is_available() else pipeline
44
 
45
- # Enable model CPU offloading for memory optimization
46
- pipeline.enable_model_cpu_offload()
 
 
 
47
 
48
  # Gradio interface function
49
  def generate_image(prompt, reference_image):
50
- # Resize and prepare reference image
51
  reference_image = reference_image.convert("RGB").resize((512, 512))
52
 
53
- # Generate image using the pipeline with ControlNet
54
- generated_image = pipeline(
55
  prompt=prompt,
56
  image=reference_image,
57
  controlnet_conditioning_scale=1.0,
@@ -68,8 +41,8 @@ interface = gr.Interface(
68
  gr.Image(type="pil", label="Reference Image (Style)")
69
  ],
70
  outputs="image",
71
- title="Image Generation with ControlNet (Reference-Only Style Transfer)",
72
- description="Generates an image based on a text prompt and style reference image using Stable Diffusion 3.5 and ControlNet (reference-only mode)."
73
  )
74
 
75
  # Launch the Gradio interface
 
1
  import gradio as gr
2
  import torch
3
+ from diffusers import StableDiffusion3Pipeline, ControlNetModel, UniPCMultistepScheduler
 
 
 
 
 
 
 
4
  from huggingface_hub import login
5
  import os
6
 
 
10
 
11
  # Model IDs for the base Stable Diffusion model and ControlNet variant
12
  model_id = "stabilityai/stable-diffusion-3.5-large-turbo"
13
+ controlnet_id = "lllyasviel/control_v11p_sd15_inpaint" # Adjust based on ControlNet needs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
+ # Load ControlNet and Stable Diffusion models
16
+ controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.bfloat16)
17
+ pipe = StableDiffusion3Pipeline.from_pretrained(model_id, controlnet=controlnet, torch_dtype=torch.bfloat16)
18
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
19
+ pipe = pipe.to("cuda") if torch.cuda.is_available() else pipe
20
 
21
  # Gradio interface function
22
  def generate_image(prompt, reference_image):
23
+ # Prepare the reference image
24
  reference_image = reference_image.convert("RGB").resize((512, 512))
25
 
26
+ # Generate the image using the pipeline with ControlNet
27
+ generated_image = pipe(
28
  prompt=prompt,
29
  image=reference_image,
30
  controlnet_conditioning_scale=1.0,
 
41
  gr.Image(type="pil", label="Reference Image (Style)")
42
  ],
43
  outputs="image",
44
+ title="Image Generation with Stable Diffusion 3.5 and ControlNet",
45
+ description="Generates an image based on a text prompt and style reference image using Stable Diffusion 3.5 and ControlNet."
46
  )
47
 
48
  # Launch the Gradio interface