vivekvar commited on
Commit
0fd6850
·
verified ·
1 Parent(s): 7d605fc

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -0
app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from PIL import Image
4
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
5
+ import gradio as gr
6
+
7
+ # Disable oneDNN custom operations
8
+ os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
9
+
10
+ # Clear PyTorch cache
11
+ torch.cuda.empty_cache()
12
+
13
+ # Check if CUDA is available
14
+ device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ if device == "cuda":
16
+ print("CUDA is available. Device count:", torch.cuda.device_count())
17
+ print("Current device:", torch.cuda.current_device())
18
+ print("Device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
19
+ else:
20
+ print("CUDA is not available. Using CPU.")
21
+
22
+ # Load ControlNet model with OpenPose pre-trained weights from Hugging Face
23
+ controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16)
24
+
25
+ # Load the Stable Diffusion model
26
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
27
+ "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
28
+ ).to(device)
29
+
30
+ # Function for inference
31
+ def generate_image(prompt, target_image, pose_image):
32
+ try:
33
+ # Resize images
34
+ target_image = target_image.resize((512, 512))
35
+ pose_image = pose_image.resize((512, 512))
36
+
37
+ # Generate image with ControlNet
38
+ output = pipe(prompt=prompt, image=target_image, control_image=pose_image, num_inference_steps=50)
39
+
40
+ # Return the result
41
+ return output["sample"][0]
42
+ except Exception as e:
43
+ print(f"Error during image generation: {e}")
44
+ return None
45
+
46
+ # Setup Gradio Interface
47
+ interface = gr.Interface(
48
+ fn=generate_image,
49
+ inputs=[
50
+ gr.Textbox(label="Prompt"),
51
+ gr.Image(label="Target Image", type="pil"),
52
+ gr.Image(label="Pose Image (Reference)", type="pil")
53
+ ],
54
+ outputs=gr.Image(label="Generated Image")
55
+ )
56
+
57
+ # Launch the interface
58
+ interface.launch()