mpatel57 committed
Commit a5eed04
1 Parent(s): 6d0ad4a

Update app.py

Files changed (1)
  1. app.py +120 -99
app.py CHANGED
@@ -15,7 +15,7 @@ from transformers import (
     CLIPTextModelWithProjection,
     CLIPVisionModelWithProjection,
     CLIPImageProcessor,
-    CLIPTokenizer
+    CLIPTokenizer,
 )
 
 from transformers import CLIPTokenizer
@@ -33,10 +33,11 @@ if torch.cuda.is_available():
     __device__ = "cuda"
     __dtype__ = torch.float16
 
+
 class Model:
     def __init__(self):
         self.device = __device__
-
+
         self.text_encoder = (
             CLIPTextModelWithProjection.from_pretrained(
                 "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
@@ -65,102 +66,48 @@ class Model:
         self.pipe = DiffusionPipeline.from_pretrained(
             "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=__dtype__
         ).to(self.device)
-
-    def inference(self, raw_data):
+
+    def inference(self, raw_data, seed):
+        generator = torch.Generator(device="cuda").manual_seed(seed)
         image_emb, negative_image_emb = self.pipe_prior(
             raw_data=raw_data,
+            generator=generator,
         ).to_tuple()
         image = self.pipe(
             image_embeds=image_emb,
             negative_image_embeds=negative_image_emb,
             num_inference_steps=50,
-            guidance_scale=4.0,
+            guidance_scale=7.5,
+            generator=generator,
         ).images[0]
         return image
-
-    def process_data(self,
-        image: PIL.Image.Image,
-        keyword: str,
-        image2: PIL.Image.Image,
-        keyword2: str,
-        text: str,
-    ) -> dict[str, Any]:
-        print(f"keyword : {keyword}, keyword2 : {keyword2}, prompt : {text}")
-        device = torch.device(self.device)
-        data: dict[str, Any] = {}
-        data['text'] = text
-
-        txt = self.tokenizer(
-            text,
-            padding='max_length',
-            truncation=True,
-            return_tensors='pt',
-        )
-        txt_items = {k: v.to(device) for k, v in txt.items()}
-        new_feats = self.text_encoder(**txt_items)
-        new_last_hidden_states = new_feats.last_hidden_state[0].cpu().numpy()
-
-        plt.imshow(image)
-        plt.title('image')
-        plt.savefig('image_testt2.png')
-        plt.show()
-
-        mask_img = self.image_processor(image, return_tensors="pt").to(__device__)
-        vision_feats = self.vision_encoder(
-            **mask_img
-        ).image_embeds
-
-        entity_tokens = self.tokenizer(keyword)["input_ids"][1:-1]
-        for tid in entity_tokens:
-            indices = np.where(txt_items["input_ids"][0].cpu().numpy() == tid)[0]
-            new_last_hidden_states[indices] = vision_feats[0].cpu().numpy()
-            print(indices)
-
-        if image2 is not None:
-            mask_img2 = self.image_processor(image2, return_tensors="pt").to(__device__)
-            vision_feats2 = self.vision_encoder(
-                **mask_img2
-            ).image_embeds
-            if keyword2 is not None:
-                entity_tokens = self.tokenizer(keyword2)["input_ids"][1:-1]
-                for tid in entity_tokens:
-                    indices = np.where(txt_items["input_ids"][0].cpu().numpy() == tid)[0]
-                    new_last_hidden_states[indices] = vision_feats2[0].cpu().numpy()
-                    print(indices)
-
-        text_feats = {
-            "prompt_embeds": new_feats.text_embeds.to(__device__),
-            "text_encoder_hidden_states": torch.tensor(new_last_hidden_states).unsqueeze(0).to(__device__),
-            "text_mask": txt_items["attention_mask"].to(__device__),
-        }
-        return text_feats
-
-    def run(self,
-        image: dict[str, PIL.Image.Image],
-        keyword: str,
-        image2: dict[str, PIL.Image.Image],
-        keyword2: str,
-        text: str,
-    ):
-
-        # aug_feats = self.process_data(image["composite"], keyword, image2["composite"], keyword2, text)
+
+    def run(
+        self,
+        image: dict[str, PIL.Image.Image],
+        keyword: str,
+        image2: dict[str, PIL.Image.Image],
+        keyword2: str,
+        text: str,
+        seed: int,
+    ):
         sub_imgs = [image["composite"]]
-        if image2:
-            sub_imgs.append(image2["composite"])
         sun_keywords = [keyword]
-        if keyword2:
+        if keyword2 and keyword2 != "no subject":
             sun_keywords.append(keyword2)
+        if image2:
+            sub_imgs.append(image2["composite"])
         raw_data = {
             "prompt": text,
            "subject_images": sub_imgs,
-            "subject_keywords": sun_keywords
+            "subject_keywords": sun_keywords,
         }
-        image = self.inference(raw_data)
+        image = self.inference(raw_data, seed)
         return image
 
-def create_demo():
 
-    USAGE = '''## To run the demo, you should:
+def create_demo():
+    USAGE = """## To run the demo, you should:
     1. Upload your image.
     2. <span style='color: red;'>**Upload a masked subject image with white blankspace or whiten out manually using brush tool.**
     3. Input a Keyword i.e. 'Dog'
@@ -169,7 +116,7 @@ def create_demo():
     4-2. Input the Keyword i.e. 'Sunglasses'
     3. Input proper text prompts, such as "A photo of Dog" or "A Dog wearing sunglasses", Please use the same keyword in the prompt.
     4. Click the Run button.
-    '''
+    """
 
     model = Model()
 
@@ -180,6 +127,8 @@ def create_demo():
 
     <p style="text-align: center; color: red;">This demo is currently hosted on either a small GPU or CPU. We will soon provide high-end GPU support.</p>
    <p style="text-align: center; color: red;">Please follow the instructions from here to run it locally: <a href="https://github.com/eclipse-t2i/lambda-eclipse-inference">GitHub Inference Code</a></p>
+
+    <a href="https://colab.research.google.com/drive/1VcqzXZmilntec3AsIyzCqlstEhX4Pa1o?usp=sharing" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
     """
     )
     gr.Markdown(USAGE)
@@ -187,28 +136,41 @@ def create_demo():
         with gr.Column():
             with gr.Group():
                 gr.Markdown(
-                    'Upload your first masked subject image or mask out marginal space')
-                image = gr.ImageEditor(label='Input', type='pil', brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"))
+                    "Upload your first masked subject image or mask out marginal space"
+                )
+                image = gr.ImageEditor(
+                    label="Input",
+                    type="pil",
+                    brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"),
+                )
                 keyword = gr.Text(
-                    label='Keyword',
+                    label="Keyword",
                     placeholder='e.g. "Dog", "Goofie"',
-                    info='Keyword for first subject')
+                    info="Keyword for first subject",
+                )
                 gr.Markdown(
-                    'For Multi-Subject generation : Upload your second masked subject image or mask out marginal space')
-                image2 = gr.ImageEditor(label='Input', type='pil', brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"))
-                keyword2= gr.Text(
-                    label='Keyword',
+                    "For Multi-Subject generation : Upload your second masked subject image or mask out marginal space"
+                )
+                image2 = gr.ImageEditor(
+                    label="Input",
+                    type="pil",
+                    brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"),
+                )
+                keyword2 = gr.Text(
+                    label="Keyword",
                     placeholder='e.g. "Sunglasses", "Grand Canyon"',
-                    info='Keyword for second subject')
+                    info="Keyword for second subject",
+                )
                 prompt = gr.Text(
-                    label='Prompt',
+                    label="Prompt",
                     placeholder='e.g. "A photo of dog", "A dog wearing sunglasses"',
-                    info='Keep the keywords used previously in the prompt')
+                    info="Keep the keywords used previously in the prompt",
+                )
 
-                run_button = gr.Button('Run')
+                run_button = gr.Button("Run")
 
         with gr.Column():
-            result = gr.Image(label='Result')
+            result = gr.Image(label="Result")
 
         inputs = [
             image,
@@ -217,18 +179,77 @@ def create_demo():
             keyword2,
             prompt,
         ]
-
+
         gr.Examples(
-            examples=[[os.path.join(os.path.dirname(__file__), "./assets/cat.png"), "cat", os.path.join(os.path.dirname(__file__), "./assets/blue_sunglasses.png"), "glasses", "A cat wearing glasses on a snowy field"]],
-            inputs = inputs,
+            examples=[
+                [
+                    os.path.join(os.path.dirname(__file__), "./assets/luffy.jpg"),
+                    "luffy",
+                    os.path.join(os.path.dirname(__file__), "./assets/white.jpg"),
+                    "no subject",
+                    "luffy holding a sword",
+                ],
+                [
+                    os.path.join(os.path.dirname(__file__), "./assets/luffy.jpg"),
+                    "luffy",
+                    os.path.join(os.path.dirname(__file__), "./assets/white.jpg"),
+                    "no subject",
+                    "luffy in the living room",
+                ],
+                [
+                    os.path.join(os.path.dirname(__file__), "./assets/teapot.jpg"),
+                    "teapot",
+                    os.path.join(os.path.dirname(__file__), "./assets/white.jpg"),
+                    "no subject",
+                    "teapot on a cobblestone street",
+                ],
+                [
+                    os.path.join(os.path.dirname(__file__), "./assets/trex.jpg"),
+                    "trex",
+                    os.path.join(os.path.dirname(__file__), "./assets/white.jpg"),
+                    "no subject",
+                    "trex near a river",
+                ],
+                [
+                    os.path.join(os.path.dirname(__file__), "./assets/cat.png"),
+                    "cat",
+                    os.path.join(
+                        os.path.dirname(__file__), "./assets/blue_sunglasses.png"
+                    ),
+                    "glasses",
+                    "A cat wearing glasses on a snowy field",
+                ],
+                [
+                    os.path.join(os.path.dirname(__file__), "./assets/statue.jpg"),
+                    "statue",
+                    os.path.join(os.path.dirname(__file__), "./assets/toilet.jpg"),
+                    "toilet",
+                    "statue sitting on a toilet",
+                ],
+                [
+                    os.path.join(os.path.dirname(__file__), "./assets/teddy.jpg"),
+                    "teddy",
+                    os.path.join(os.path.dirname(__file__), "./assets/luffy_hat.jpg"),
+                    "hat",
+                    "a teddy wearing the hat at a beach",
+                ],
+                [
+                    os.path.join(os.path.dirname(__file__), "./assets/chair.jpg"),
+                    "chair",
+                    os.path.join(os.path.dirname(__file__), "./assets/table.jpg"),
+                    "table",
+                    "a chair and table in living room",
+                ],
+            ],
+            inputs=inputs,
             fn=model.run,
             outputs=result,
         )
-
+
         run_button.click(fn=model.run, inputs=inputs, outputs=result)
     return demo
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     demo = create_demo()
-    demo.queue(max_size=20).launch()
+    demo.queue(max_size=20).launch()
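
The reworked `run()`/`inference()` path threads a user-supplied seed through a `torch.Generator` that is shared by the prior and the decoder, so identical inputs plus the same seed should reproduce the same image. A minimal sketch of calling the updated `Model` outside the Gradio UI, assuming the revised `app.py` is importable as `app`, that `./assets/cat.png` exists, and that a CUDA device is available (all three are assumptions for illustration, not part of this commit):

```python
# Minimal usage sketch of the updated Model.run() signature.
# Assumes: the revised app.py is importable as `app`, ./assets/cat.png exists,
# and a CUDA GPU is present (inference() builds a torch.Generator on "cuda").
import PIL.Image

from app import Model

model = Model()
cat = PIL.Image.open("./assets/cat.png")

result = model.run(
    image={"composite": cat},      # gr.ImageEditor passes a dict with a "composite" key
    keyword="cat",
    image2=None,                   # falsy -> second subject image is skipped
    keyword2="no subject",         # sentinel keyword ignored by run()
    text="A photo of a cat on a snowy field",
    seed=42,                       # same seed + same inputs -> same output image
)
result.save("cat_result.png")
```

Passing the seed explicitly keeps both stages (the `pipe_prior` embedding step and the Kandinsky decoder) on one generator, which is what makes the output deterministic across runs.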