Spaces:
Paused
Paused
test gradio
Browse files
app.py
CHANGED
@@ -1,33 +1,43 @@
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
-
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
|
|
|
4 |
import os
|
5 |
from huggingface_hub import login
|
6 |
|
7 |
-
#
|
8 |
token = os.getenv("HF_TOKEN")
|
9 |
login(token=token)
|
10 |
|
11 |
-
#
|
12 |
-
model_id = "
|
13 |
-
controlnet_id = "lllyasviel/control_v11p_sd15_inpaint" #
|
14 |
|
15 |
-
# Load ControlNet
|
16 |
controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float32)
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
controlnet=controlnet,
|
20 |
-
|
|
|
21 |
)
|
22 |
-
pipeline = pipeline.to("cuda") if torch.cuda.is_available() else pipeline
|
23 |
|
|
|
24 |
|
25 |
-
# Define
|
26 |
def generate_image(prompt, reference_image):
|
27 |
-
# Ensure the reference image is in the correct format
|
28 |
reference_image = reference_image.convert("RGB").resize((512, 512))
|
29 |
-
|
30 |
-
# Generate the image with ControlNet
|
31 |
generated_image = pipeline(
|
32 |
prompt=prompt,
|
33 |
image=reference_image,
|
@@ -37,7 +47,6 @@ def generate_image(prompt, reference_image):
|
|
37 |
).images[0]
|
38 |
return generated_image
|
39 |
|
40 |
-
|
41 |
# Set up Gradio interface
|
42 |
interface = gr.Interface(
|
43 |
fn=generate_image,
|
@@ -47,7 +56,7 @@ interface = gr.Interface(
|
|
47 |
],
|
48 |
outputs="image",
|
49 |
title="Image Generation with Reference-Only Style Transfer",
|
50 |
-
description="Generate an image based on a text prompt and style reference image using Stable Diffusion
|
51 |
)
|
52 |
|
53 |
# Launch the Gradio interface
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
+
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UNet2DConditionModel, AutoencoderKL, DDIMScheduler
|
4 |
+
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
5 |
import os
|
6 |
from huggingface_hub import login
|
7 |
|
8 |
+
# Authenticate with Hugging Face
|
9 |
# Read the Hugging Face access token from the environment (on a Space this is
# normally supplied as a repository secret named HF_TOKEN).
token = os.getenv("HF_TOKEN")

# Only authenticate when a token is actually configured. Calling login() with
# token=None makes huggingface_hub fall back to an interactive prompt, which
# cannot be answered in a headless deployment. Public models still download
# without authentication.
if token:
    login(token=token)
|
11 |
|
12 |
+
# Define model and controlnet IDs
|
13 |
+
model_id = "runwayml/stable-diffusion-v1-5" # Use a fully compatible model
|
14 |
+
controlnet_id = "lllyasviel/control_v11p_sd15_inpaint" # ControlNet variant
|
15 |
|
16 |
+
# Load ControlNet and other components
|
17 |
controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float32)

# Load every Stable Diffusion component from the same checkpoint (`model_id`)
# so the pieces are mutually compatible. The text encoder and tokenizer must be
# the ones the UNet was trained with: a CLIP text model from a different
# checkpoint (e.g. "openai/clip-vit-base-patch32") produces embeddings whose
# width does not match the UNet's cross-attention input, and generation fails.
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")

# Assemble the pipeline from the individually loaded components.
# `safety_checker` is a required constructor argument; it is passed explicitly
# as None (no NSFW filtering, matching the fact that no checker was ever
# loaded here) and `requires_safety_checker=False` records that this is
# intentional so the pipeline does not warn about it.
# NOTE(review): if content filtering is wanted, load a safety checker and pass
# it here instead of None.
pipeline = StableDiffusionControlNetPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    controlnet=controlnet,
    scheduler=scheduler,
    safety_checker=None,
    feature_extractor=feature_extractor,
    requires_safety_checker=False,
)
|
|
|
35 |
|
36 |
+
# Run on the GPU when one is present; otherwise keep the pipeline where it is.
if torch.cuda.is_available():
    pipeline = pipeline.to("cuda")
|
37 |
|
38 |
+
# Define Gradio interface function
|
39 |
def generate_image(prompt, reference_image):
|
|
|
40 |
reference_image = reference_image.convert("RGB").resize((512, 512))
|
|
|
|
|
41 |
generated_image = pipeline(
|
42 |
prompt=prompt,
|
43 |
image=reference_image,
|
|
|
47 |
).images[0]
|
48 |
return generated_image
|
49 |
|
|
|
50 |
# Set up Gradio interface
|
51 |
interface = gr.Interface(
|
52 |
fn=generate_image,
|
|
|
56 |
],
|
57 |
outputs="image",
|
58 |
title="Image Generation with Reference-Only Style Transfer",
|
59 |
+
description="Generate an image based on a text prompt and style reference image using Stable Diffusion with ControlNet."
|
60 |
)
|
61 |
|
62 |
# Launch the Gradio interface
|