svjack commited on
Commit
3ae8d58
1 Parent(s): b830f60

Upload control_app.py

Browse files
Files changed (1) hide show
  1. control_app.py +131 -0
control_app.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ !git clone https://huggingface.co/spaces/radames/SPIGA-face-alignment-headpose-estimator
3
+ !cp -r SPIGA-face-alignment-headpose-estimator/SPIGA .
4
+ !pip install -r SPIGA/requirements.txt
5
+ !pip install datasets
6
+ !huggingface-cli login
7
+ '''
8
+ from pred_color import *
9
+ import gradio as gr
10
+
11
+ from diffusers import (
12
+ AutoencoderKL,
13
+ ControlNetModel,
14
+ DDPMScheduler,
15
+ StableDiffusionControlNetPipeline,
16
+ UNet2DConditionModel,
17
+ UniPCMultistepScheduler,
18
+ )
19
+ import torch
20
+ from diffusers.utils import load_image
21
+
22
+ controlnet_model_name_or_path = "svjack/ControlNet-Face-Zh"
23
+ controlnet = ControlNetModel.from_pretrained(controlnet_model_name_or_path)
24
+ #controlnet = controlnet.to("cuda")
25
+
26
+ base_model_path = "IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1"
27
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
28
+ base_model_path, controlnet=controlnet,
29
+ #torch_dtype=torch.float16
30
+ )
31
+
32
+ # speed up diffusion process with faster scheduler and memory optimization
33
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
34
+ #pipe.enable_model_cpu_offload()
35
+ #pipe = pipe.to("cuda")
36
+
37
+ if torch.cuda.is_available():
38
+ pipe = pipe.to("cuda")
39
+ else:
40
+ #pipe.enable_model_cpu_offload()
41
+ pass
42
+
43
+ example_sample = [
44
+ ["Protector_Cromwell_style.png", "戴帽子穿灰色衣服的男子"]
45
+ ]
46
+
47
+ from PIL import Image
48
+ def pred_func(image, prompt):
49
+ out = single_pred_features(image)
50
+ if type(out) == type({}):
51
+ #return out["spiga_seg"]
52
+ control_image = out["spiga_seg"]
53
+ if type(image) == type("") and os.path.exists(image):
54
+ image = Image.open(image).convert("RGB")
55
+ elif hasattr(image, "shape"):
56
+ image = Image.fromarray(image).convert("RGB")
57
+ else:
58
+ image = image.convert("RGB")
59
+ image = image.resize((512, 512))
60
+
61
+ generator = torch.manual_seed(0)
62
+ image = pipe(
63
+ prompt, num_inference_steps=50,
64
+ generator=generator, image=control_image
65
+ ).images[0]
66
+ return control_image ,image
67
+
68
+
69
+ gr=gr.Interface(fn=pred_func, inputs=['image','text'],
70
+ outputs=[gr.Image(label='output').style(height=512),
71
+ gr.Image(label='output').style(height=512)],
72
+ examples=example_sample if example_sample else None,
73
+ )
74
+ gr.launch(share=False)
75
+
76
+ if __name__ == "__main__":
77
+ '''
78
+ control_image = load_image("./conditioning_image_1.png")
79
+ prompt = "戴眼镜的中年男子"
80
+ # generate image
81
+ generator = torch.manual_seed(0)
82
+ image = pipe(
83
+ prompt, num_inference_steps=50, generator=generator, image=control_image
84
+ ).images[0]
85
+ image
86
+
87
+ control_image = load_image("./conditioning_image_1.png")
88
+ prompt = "穿蓝色衣服的秃头男子"
89
+ # generate image
90
+ generator = torch.manual_seed(0)
91
+ image = pipe(
92
+ prompt, num_inference_steps=50, generator=generator, image=control_image
93
+ ).images[0]
94
+ image
95
+
96
+ control_image = load_image("./conditioning_image_2.png")
97
+ prompt = "金色头发的美丽女子"
98
+ # generate image
99
+ generator = torch.manual_seed(0)
100
+ image = pipe(
101
+ prompt, num_inference_steps=50, generator=generator, image=control_image
102
+ ).images[0]
103
+ image
104
+
105
+ control_image = load_image("./conditioning_image_2.png")
106
+ prompt = "绿色运动衫的男子"
107
+ # generate image
108
+ generator = torch.manual_seed(0)
109
+ image = pipe(
110
+ prompt, num_inference_steps=50, generator=generator, image=control_image
111
+ ).images[0]
112
+ image
113
+
114
+ from huggingface_hub import HfApi
115
+ hf_api = HfApi()
116
+
117
+ hf_api.upload_file(
118
+ path_or_fileobj = "TSD_save_only/diffusion_pytorch_model.bin",
119
+ path_in_repo = "diffusion_pytorch_model.bin",
120
+ repo_id = "svjack/ControlNet-Face-Zh",
121
+ repo_type = "model",
122
+ )
123
+
124
+ hf_api.upload_file(
125
+ path_or_fileobj = "TSD_save_only/config.json",
126
+ path_in_repo = "config.json",
127
+ repo_id = "svjack/ControlNet-Face-Zh",
128
+ repo_type = "model",
129
+ )
130
+ '''
131
+ pass