ssboost committed
Commit d0ce980 · verified · 1 Parent(s): e5c7871

Update app.py

Files changed (1): app.py (+197 -1)
app.py CHANGED
@@ -1,2 +1,198 @@
+ import spaces
+ import random
+ import torch
+ from huggingface_hub import snapshot_download
+ from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
+ from kolors.pipelines import pipeline_stable_diffusion_xl_chatglm_256_ipadapter, pipeline_stable_diffusion_xl_chatglm_256
+ from kolors.models.modeling_chatglm import ChatGLMModel
+ from kolors.models.tokenization_chatglm import ChatGLMTokenizer
+ from kolors.models import unet_2d_condition
+ from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
+ import gradio as gr
+ import numpy as np
+ from huggingface_hub import InferenceClient
  import os
- exec(os.environ.get('APP'))
+
+ # Initialize the Cohere model
+ client = InferenceClient("CohereForAI/c4ai-command-r-plus", token=os.getenv("HF_TOKEN"))
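+ # Note: this client is defined here but not referenced anywhere else in this file.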
+
+ device = "cuda"
+ ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")
+ ckpt_IPA_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors-IP-Adapter-Plus")
+
+ text_encoder = ChatGLMModel.from_pretrained(f'{ckpt_dir}/text_encoder', torch_dtype=torch.float16).half().to(device)
+ tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
+ vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to(device)
+ scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
+ unet_t2i = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to(device)
+ unet_i2i = unet_2d_condition.UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to(device)
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(f'{ckpt_IPA_dir}/image_encoder', ignore_mismatched_sizes=True).to(dtype=torch.float16, device=device)
+ ip_img_size = 336
+ clip_image_processor = CLIPImageProcessor(size=ip_img_size, crop_size=ip_img_size)
+
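+ # Two pipelines share the same Kolors weights: pipe_t2i handles plain
+ # text-to-image, while pipe_i2i adds the CLIP image encoder so that
+ # IP-Adapter image prompts can be used.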
+ pipe_t2i = pipeline_stable_diffusion_xl_chatglm_256.StableDiffusionXLPipeline(
+     vae=vae,
+     text_encoder=text_encoder,
+     tokenizer=tokenizer,
+     unet=unet_t2i,
+     scheduler=scheduler,
+     force_zeros_for_empty_prompt=False
+ ).to(device)
+
+ pipe_i2i = pipeline_stable_diffusion_xl_chatglm_256_ipadapter.StableDiffusionXLPipeline(
+     vae=vae,
+     text_encoder=text_encoder,
+     tokenizer=tokenizer,
+     unet=unet_i2i,
+     scheduler=scheduler,
+     image_encoder=image_encoder,
+     feature_extractor=clip_image_processor,
+     force_zeros_for_empty_prompt=False
+ ).to(device)
+
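+ # Compatibility shim: alias the UNet's encoder_hid_proj under the name that
+ # the Kolors IP-Adapter pipeline appears to expect.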
+ if hasattr(pipe_i2i.unet, 'encoder_hid_proj'):
+     pipe_i2i.unet.text_encoder_hid_proj = pipe_i2i.unet.encoder_hid_proj
+
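+ # Load the IP-Adapter Plus weights from the checkpoint directory downloaded above.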
+ pipe_i2i.load_ip_adapter(f'{ckpt_IPA_dir}', subfolder="", weight_name=["ip_adapter_plus_general.bin"])
+
+ MAX_SEED = np.iinfo(np.int32).max
+ MAX_IMAGE_SIZE = 1024
+
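+ # Single entry point for both modes: with no image prompt it runs the
+ # text-to-image pipeline, otherwise the IP-Adapter pipeline.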
+ @spaces.GPU
+ def infer(prompt,
+           ip_adapter_image=None,
+           ip_adapter_scale=0.5,
+           negative_prompt="",
+           seed=0,
+           randomize_seed=False,
+           width=1024,
+           height=1024,
+           guidance_scale=5.0,
+           num_inference_steps=25
+           ):
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     generator = torch.Generator().manual_seed(seed)
+
+     if ip_adapter_image is None:
+         pipe_t2i.to(device)
+         image = pipe_t2i(
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             guidance_scale=guidance_scale,
+             num_inference_steps=num_inference_steps,
+             width=width,
+             height=height,
+             generator=generator
+         ).images[0]
+         image.save("generated_image.jpg")  # save with a .jpg extension
+         return image, "generated_image.jpg"
+     else:
+         pipe_i2i.to(device)
+         image_encoder.to(device)
+         pipe_i2i.image_encoder = image_encoder
+         pipe_i2i.set_ip_adapter_scale([ip_adapter_scale])
+         image = pipe_i2i(
+             prompt=prompt,
+             ip_adapter_image=[ip_adapter_image],
+             negative_prompt=negative_prompt,
+             height=height,
+             width=width,
+             num_inference_steps=num_inference_steps,
+             guidance_scale=guidance_scale,
+             num_images_per_prompt=1,
+             generator=generator
+         ).images[0]
+         image.save("generated_image.jpg")  # save with a .jpg extension
+         return image, "generated_image.jpg"
+
+ css = """
+ #col-left {
+     margin: 0 auto;
+     max-width: 600px;
+ }
+ #col-right {
+     margin: 0 auto;
+     max-width: 750px;
+ }
+ """
+
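+ # UI layout: prompt and optional image prompt on the left; the generated
+ # image and a downloadable file on the right.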
+ with gr.Blocks(css=css) as Kolors:
+     with gr.Row():
+         with gr.Column(elem_id="col-left"):
+             with gr.Row():
+                 generated_prompt = gr.Textbox(
+                     label="Prompt",
+                     placeholder="Enter the prompt to use for image generation",
+                     lines=2
+                 )
+             with gr.Row():
+                 ip_adapter_image = gr.Image(label="Image Prompt (optional)", type="pil")
+             with gr.Row(visible=False):  # advanced settings hidden
+                 negative_prompt = gr.Textbox(
+                     label="Negative prompt",
+                     placeholder="Enter a negative prompt",
+                     visible=True,
+                 )
+                 seed = gr.Slider(
+                     label="Seed",
+                     minimum=0,
+                     maximum=MAX_SEED,
+                     step=1,
+                     value=0,
+                 )
+                 randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+             with gr.Row():
+                 width = gr.Slider(
+                     label="Width",
+                     minimum=256,
+                     maximum=MAX_IMAGE_SIZE,
+                     step=32,
+                     value=1024,
+                 )
+                 height = gr.Slider(
+                     label="Height",
+                     minimum=256,
+                     maximum=MAX_IMAGE_SIZE,
+                     step=32,
+                     value=1024,
+                 )
+             with gr.Row():
+                 guidance_scale = gr.Slider(
+                     label="Guidance scale",
+                     minimum=0.0,
+                     maximum=10.0,
+                     step=0.1,
+                     value=5.0,
+                 )
+                 num_inference_steps = gr.Slider(
+                     label="Number of inference steps",
+                     minimum=10,
+                     maximum=50,
+                     step=1,
+                     value=25,
+                 )
+             with gr.Row():
+                 ip_adapter_scale = gr.Slider(
+                     label="Image influence scale",
+                     info="Use 1 for creating variations",
+                     minimum=0.0,
+                     maximum=1.0,
+                     step=0.05,
+                     value=0.5,
+                 )
+             with gr.Row():
+                 run_button = gr.Button("Generate Image")
+
+         with gr.Column(elem_id="col-right"):
+             result = gr.Image(label="Result", show_label=False)
+             download_button = gr.File(label="Download Image")
+
+     # Generate the image and provide the download file path
+     run_button.click(
+         fn=infer,
+         inputs=[generated_prompt, ip_adapter_image, ip_adapter_scale, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
+         outputs=[result, download_button]
+     )
+
+ Kolors.queue().launch(debug=True)