amos1088 committed
Commit f954913 · 1 Parent(s): c1497a6

test gradio

Files changed (1)
  1. app.py +25 -16
app.py CHANGED
@@ -1,33 +1,43 @@
 import gradio as gr
 import torch
-from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
+from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UNet2DConditionModel, AutoencoderKL, DDIMScheduler
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 import os
 from huggingface_hub import login
 
-# Log in with your Hugging Face token (assumed stored in HF_TOKEN)
+# Authenticate with Hugging Face
 token = os.getenv("HF_TOKEN")
 login(token=token)
 
-# Model IDs for the base Stable Diffusion model and ControlNet variant
-model_id = "stabilityai/stable-diffusion-3.5-large-turbo"
-controlnet_id = "lllyasviel/control_v11p_sd15_inpaint" # Make sure this ControlNet is compatible
+# Define model and controlnet IDs
+model_id = "runwayml/stable-diffusion-v1-5" # Use a fully compatible model
+controlnet_id = "lllyasviel/control_v11p_sd15_inpaint" # ControlNet variant
 
-# Load ControlNet model and pipeline
+# Load ControlNet and other components
 controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float32)
-pipeline = StableDiffusionControlNetPipeline.from_pretrained(
-    model_id,
+unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
+vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
+scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
+text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
+tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
+feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
+
+# Initialize the pipeline with all required components
+pipeline = StableDiffusionControlNetPipeline(
+    vae=vae,
+    text_encoder=text_encoder,
+    tokenizer=tokenizer,
+    unet=unet,
     controlnet=controlnet,
-    torch_dtype=torch.float32
+    scheduler=scheduler,
+    feature_extractor=feature_extractor
 )
-pipeline = pipeline.to("cuda") if torch.cuda.is_available() else pipeline
 
+pipeline = pipeline.to("cuda") if torch.cuda.is_available() else pipeline
 
-# Define the Gradio interface function
+# Define Gradio interface function
 def generate_image(prompt, reference_image):
-    # Ensure the reference image is in the correct format
     reference_image = reference_image.convert("RGB").resize((512, 512))
-
-    # Generate the image with ControlNet
     generated_image = pipeline(
         prompt=prompt,
         image=reference_image,
@@ -37,7 +47,6 @@ def generate_image(prompt, reference_image):
     ).images[0]
     return generated_image
 
-
 # Set up Gradio interface
 interface = gr.Interface(
     fn=generate_image,
@@ -47,7 +56,7 @@ interface = gr.Interface(
     ],
     outputs="image",
     title="Image Generation with Reference-Only Style Transfer",
-    description="Generate an image based on a text prompt and style reference image using Stable Diffusion 3.5 with ControlNet (reference-only mode)."
+    description="Generate an image based on a text prompt and style reference image using Stable Diffusion with ControlNet."
 )
 
 # Launch the Gradio interface
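For context, the manual component assembly introduced in this commit can also be expressed with StableDiffusionControlNetPipeline.from_pretrained, which loads the UNet, VAE, scheduler, text encoder, tokenizer, and feature extractor bundled with the base checkpoint. A minimal sketch, assuming the same runwayml/stable-diffusion-v1-5 and lllyasviel/control_v11p_sd15_inpaint IDs as the commit; the prompt, image path, and step count are illustrative placeholders, not taken from app.py.

import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from PIL import Image

model_id = "runwayml/stable-diffusion-v1-5"              # same base model as the commit
controlnet_id = "lllyasviel/control_v11p_sd15_inpaint"   # same ControlNet variant

# Load the ControlNet weights, then let from_pretrained pull every other component
# (UNet, VAE, scheduler, text encoder, tokenizer, feature extractor) from the checkpoint.
controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float32)
pipeline = StableDiffusionControlNetPipeline.from_pretrained(
    model_id,
    controlnet=controlnet,
    torch_dtype=torch.float32,
)
pipeline = pipeline.to("cuda") if torch.cuda.is_available() else pipeline

# Hypothetical usage: prompt, input file, and step count are placeholders.
reference_image = Image.open("reference.png").convert("RGB").resize((512, 512))
result = pipeline(
    prompt="a watercolor landscape",
    image=reference_image,
    num_inference_steps=30,
).images[0]
result.save("output.png")

One practical difference: from_pretrained uses the CLIP text encoder and tokenizer that ship with the base checkpoint, whereas the commit loads them explicitly from openai/clip-vit-base-patch32.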