amos1088 committed
Commit 0737dc8 · 1 Parent(s): f954913

test gradio

Files changed (1)
app.py +31 -28
app.py CHANGED
@@ -1,43 +1,45 @@
 import gradio as gr
 import torch
-from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UNet2DConditionModel, AutoencoderKL, DDIMScheduler
+from diffusers import (
+    StableDiffusionControlNetPipeline,
+    ControlNetModel,
+    UNet2DConditionModel,
+    AutoencoderKL,
+    UniPCMultistepScheduler,
+)
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
-import os
 from huggingface_hub import login
+import os
 
-# Authenticate with Hugging Face
+# Log in to Hugging Face with token from environment variables
 token = os.getenv("HF_TOKEN")
 login(token=token)
 
-# Define model and controlnet IDs
-model_id = "runwayml/stable-diffusion-v1-5"  # Use a fully compatible model
-controlnet_id = "lllyasviel/control_v11p_sd15_inpaint"  # ControlNet variant
-
-# Load ControlNet and other components
-controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float32)
-unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
-vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
-scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
-text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
-tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
-feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
-
-# Initialize the pipeline with all required components
-pipeline = StableDiffusionControlNetPipeline(
-    vae=vae,
-    text_encoder=text_encoder,
-    tokenizer=tokenizer,
-    unet=unet,
+# Model and ControlNet IDs
+model_id = "runwayml/stable-diffusion-v1-5"  # Known compatible model with ControlNet
+controlnet_id = "lllyasviel/sd-controlnet-canny"  # ControlNet model for edge detection
+
+# Load ControlNet model and other components
+controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float16)
+pipeline = StableDiffusionControlNetPipeline.from_pretrained(
+    model_id,
     controlnet=controlnet,
-    scheduler=scheduler,
-    feature_extractor=feature_extractor
+    torch_dtype=torch.float16
 )
 
-pipeline = pipeline.to("cuda") if torch.cuda.is_available() else pipeline
+# Optional: Set up the faster scheduler
+pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
 
-# Define Gradio interface function
+# Enable CPU offloading for memory optimization
+pipeline.enable_model_cpu_offload()
+
+
+# Gradio interface function
 def generate_image(prompt, reference_image):
+    # Resize and prepare reference image
     reference_image = reference_image.convert("RGB").resize((512, 512))
+
+    # Generate image using the pipeline with ControlNet
     generated_image = pipeline(
         prompt=prompt,
         image=reference_image,
@@ -47,6 +49,7 @@ def generate_image(prompt, reference_image):
     ).images[0]
     return generated_image
 
+
 # Set up Gradio interface
 interface = gr.Interface(
     fn=generate_image,
@@ -55,8 +58,8 @@ interface = gr.Interface(
         gr.Image(type="pil", label="Reference Image (Style)")
     ],
     outputs="image",
-    title="Image Generation with Reference-Only Style Transfer",
-    description="Generate an image based on a text prompt and style reference image using Stable Diffusion with ControlNet."
+    title="Image Generation with ControlNet (Reference-Only Style Transfer)",
+    description="Generates an image based on a text prompt and style reference image using Stable Diffusion and ControlNet (reference-only mode)."
 )
 
 # Launch the Gradio interface
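
Note that the new revision switches to lllyasviel/sd-controlnet-canny, which conditions on Canny edge maps rather than raw RGB images, so the reference image would normally be converted to an edge map before being passed to the pipeline. Below is a minimal sketch of that preprocessing step using OpenCV; the helper name make_canny_condition and the 100/200 thresholds are illustrative assumptions, not part of this commit.

import cv2
import numpy as np
from PIL import Image

def make_canny_condition(reference_image: Image.Image) -> Image.Image:
    # Work on a fixed-size RGB copy, matching the 512x512 resize in app.py
    array = np.array(reference_image.convert("RGB").resize((512, 512)))
    # Detect edges; 100/200 are common low/high hysteresis thresholds (illustrative)
    edges = cv2.Canny(array, 100, 200)
    # Replicate the single-channel edge map across three channels for the pipeline
    edges = np.stack([edges] * 3, axis=-1)
    return Image.fromarray(edges)

Inside generate_image, the result of make_canny_condition(reference_image) would then be passed as the pipeline's image argument in place of the raw reference image.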