Garrett Goon committed on
Commit
1da6b3f
1 Parent(s): 7b17c3f

updated syntax to match reviewed repo

Browse files
__pycache__/utils.cpython-38.pyc CHANGED
Binary files a/__pycache__/utils.cpython-38.pyc and b/__pycache__/utils.cpython-38.pyc differ
 
app.py CHANGED
@@ -34,37 +34,37 @@ pipeline = StableDiffusionPipeline.from_pretrained(
34
  CONCEPT_PATH = pathlib.Path("learned_embeddings_dict.pt")
35
  learned_embeddings_dict = torch.load(CONCEPT_PATH)
36
 
37
- concept_to_dummy_tokens_map = {}
38
  for concept_token, embedding_dict in learned_embeddings_dict.items():
39
- initializer_tokens = embedding_dict["initializer_tokens"]
40
  learned_embeddings = embedding_dict["learned_embeddings"]
41
  (
42
  initializer_ids,
43
  dummy_placeholder_ids,
44
- dummy_placeholder_tokens,
45
  ) = utils.add_new_tokens_to_tokenizer(
46
- concept_token=concept_token,
47
- initializer_tokens=initializer_tokens,
48
  tokenizer=pipeline.tokenizer,
49
  )
50
  pipeline.text_encoder.resize_token_embeddings(len(pipeline.tokenizer))
51
  token_embeddings = pipeline.text_encoder.get_input_embeddings().weight.data
52
  for d_id, tensor in zip(dummy_placeholder_ids, learned_embeddings):
53
  token_embeddings[d_id] = tensor
54
- concept_to_dummy_tokens_map[concept_token] = dummy_placeholder_tokens
55
 
56
 
57
- def replace_concept_tokens(text: str):
58
- for concept_token, dummy_tokens in concept_to_dummy_tokens_map.items():
59
- text = text.replace(concept_token, dummy_tokens)
60
  return text
61
 
62
  def inference(prompt: str, guidance_scale: int, num_inference_steps: int, seed: int):
63
  if not prompt:
64
  raise ValueError("Please enter a prompt.")
65
- if '<det-logo>' not in prompt:
66
- raise ValueError('"<det-logo>" must be included in the prompt.')
67
- prompt = replace_concept_tokens(prompt)
68
  generator = torch.Generator(device=device).manual_seed(seed)
69
  output = pipeline(
70
  prompt=[prompt] * BATCH_SIZE,
@@ -275,35 +275,35 @@ block = gr.Blocks(css=css)
275
 
276
  examples = [
277
  [
278
- "a Van Gogh painting of a <det-logo> with thick strokes, masterful composition",
279
  # 4,
280
  # 45,
281
  # 7.5,
282
  # 1024,
283
  ],
284
  [
285
- "Futuristic <det-logo> in a desert, painting, octane render, 4 k, anime sky, warm colors",
286
  # 4,
287
  # 45,
288
  # 7,
289
  # 1024,
290
  ],
291
  [
292
- "cell shaded cartoon of a <det-logo>, subtle colors, post grunge, concept art by josan gonzales and wlop, by james jean, victo ngai, david rubin, mike mignola, deviantart, art by artgem",
293
  # 4,
294
  # 45,
295
  # 7,
296
  # 1024,
297
  ],
298
  [
299
- "a surreal Salvador Dali painting of a <det-logo>, soft blended colors",
300
  # 4,
301
  # 45,
302
  # 7,
303
  # 1024,
304
  ],
305
  [
306
- "Beautiful tarot illustration of a <det-logo>, in the style of james jean and victo ngai, mystical colors, trending on artstation",
307
  # 4,
308
  # 45,
309
  # 7,
@@ -334,10 +334,10 @@ with block:
334
  with gr.Box():
335
  with gr.Row(elem_id="prompt-container").style(equal_height=True):
336
  prompt = gr.Textbox(
337
- label='Enter a prompt including "<det-logo>"',
338
  show_label=False,
339
  max_lines=1,
340
- placeholder='Enter a prompt including "<det-logo>"',
341
  elem_id="prompt-text-input",
342
  ).style(
343
  container=False,
 
34
  CONCEPT_PATH = pathlib.Path("learned_embeddings_dict.pt")
35
  learned_embeddings_dict = torch.load(CONCEPT_PATH)
36
 
37
+ concept_to_dummy_strs_map = {}
38
  for concept_token, embedding_dict in learned_embeddings_dict.items():
39
+ initializer_strs = embedding_dict["initializer_strs"]
40
  learned_embeddings = embedding_dict["learned_embeddings"]
41
  (
42
  initializer_ids,
43
  dummy_placeholder_ids,
44
+ dummy_placeholder_strs,
45
  ) = utils.add_new_tokens_to_tokenizer(
46
+ concept_str=concept_token,
47
+ initializer_strs=initializer_strs,
48
  tokenizer=pipeline.tokenizer,
49
  )
50
  pipeline.text_encoder.resize_token_embeddings(len(pipeline.tokenizer))
51
  token_embeddings = pipeline.text_encoder.get_input_embeddings().weight.data
52
  for d_id, tensor in zip(dummy_placeholder_ids, learned_embeddings):
53
  token_embeddings[d_id] = tensor
54
+ concept_to_dummy_strs_map[concept_token] = dummy_placeholder_strs
55
 
56
 
57
def replace_concept_strs(text: str):
    """Swap every known concept token in *text* for its dummy-token string.

    Relies on the module-level ``concept_to_dummy_strs_map`` built at load
    time, which maps each concept token to the space-joined dummy
    placeholder strings whose embeddings were learned for it.
    """
    for concept_str, replacement in concept_to_dummy_strs_map.items():
        text = text.replace(concept_str, replacement)
    return text
61
 
62
  def inference(prompt: str, guidance_scale: int, num_inference_steps: int, seed: int):
63
  if not prompt:
64
  raise ValueError("Please enter a prompt.")
65
+ if 'det-logo' not in prompt:
66
+ raise ValueError('"det-logo" must be included in the prompt.')
67
+ prompt = replace_concept_strs(prompt)
68
  generator = torch.Generator(device=device).manual_seed(seed)
69
  output = pipeline(
70
  prompt=[prompt] * BATCH_SIZE,
 
275
 
276
  examples = [
277
  [
278
+ "a Van Gogh painting of a det-logo with thick strokes, masterful composition",
279
  # 4,
280
  # 45,
281
  # 7.5,
282
  # 1024,
283
  ],
284
  [
285
+ "Futuristic det-logo in a desert, painting, octane render, 4 k, anime sky, warm colors",
286
  # 4,
287
  # 45,
288
  # 7,
289
  # 1024,
290
  ],
291
  [
292
+ "cell shaded cartoon of a det-logo, subtle colors, post grunge, concept art by josan gonzales and wlop, by james jean, victo ngai, david rubin, mike mignola, deviantart, art by artgem",
293
  # 4,
294
  # 45,
295
  # 7,
296
  # 1024,
297
  ],
298
  [
299
+ "a surreal Salvador Dali painting of a det-logo, soft blended colors",
300
  # 4,
301
  # 45,
302
  # 7,
303
  # 1024,
304
  ],
305
  [
306
+ "Beautiful tarot illustration of a det-logo, in the style of james jean and victo ngai, mystical colors, trending on artstation",
307
  # 4,
308
  # 45,
309
  # 7,
 
334
  with gr.Box():
335
  with gr.Row(elem_id="prompt-container").style(equal_height=True):
336
  prompt = gr.Textbox(
337
+ label='Enter a prompt including "det-logo"',
338
  show_label=False,
339
  max_lines=1,
340
+ placeholder='Enter a prompt including "det-logo"',
341
  elem_id="prompt-text-input",
342
  ).style(
343
  container=False,
learned_embeddings_dict.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:73ab240e6ef7b16a70e14b4625882d8f63050f1d96ffc0eef6e0e0caa2844109
3
  size 16235
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5184c747567ac6240bd45b701cb29416752fcc925b2a967a811c28729451b942
3
  size 16235
utils.py CHANGED
@@ -1,59 +1,49 @@
1
- from typing import List, Sequence, Tuple
2
 
3
  import torch
4
  import torch.nn as nn
5
 
6
 
7
  def add_new_tokens_to_tokenizer(
8
- concept_token: str,
9
- initializer_tokens: Sequence[str],
10
  tokenizer: nn.Module,
11
- ) -> Tuple[List[int], List[int], str]:
12
  """Helper function for adding new tokens to the tokenizer and extending the corresponding
13
  embeddings appropriately, given a single concept token and its sequence of corresponding
14
- initializer tokens. Returns the lists of ids for the initializer tokens and their dummy
15
  replacements, as well as the string representation of the dummies.
16
  """
 
 
 
 
17
  initializer_ids = tokenizer(
18
- initializer_tokens,
19
- padding="max_length",
20
- truncation=True,
21
- max_length=tokenizer.model_max_length,
22
  return_tensors="pt",
23
  add_special_tokens=False,
24
- ).input_ids
25
-
26
- try:
27
- special_token_ids = tokenizer.all_special_ids
28
- except AttributeError:
29
- special_token_ids = []
30
-
31
- non_special_initializer_locations = torch.isin(
32
- initializer_ids, torch.tensor(special_token_ids), invert=True
33
- )
34
- non_special_initializer_ids = initializer_ids[non_special_initializer_locations]
35
- if len(non_special_initializer_ids) == 0:
36
- raise ValueError(
37
- f'"{initializer_tokens}" maps to trivial tokens, please choose a different initializer.'
38
- )
39
 
40
  # Add a dummy placeholder token for every token in the initializer.
41
- dummy_placeholder_token_list = [
42
- f"{concept_token}_{n}" for n in range(len(non_special_initializer_ids))
43
- ]
44
- dummy_placeholder_tokens = " ".join(dummy_placeholder_token_list)
45
- num_added_tokens = tokenizer.add_tokens(dummy_placeholder_token_list)
46
- if num_added_tokens != len(dummy_placeholder_token_list):
47
- raise ValueError(
48
- f"Subset of {dummy_placeholder_token_list} tokens already exist in tokenizer."
49
- )
50
-
51
- dummy_placeholder_ids = tokenizer.convert_tokens_to_ids(
52
- dummy_placeholder_token_list
53
- )
54
- # Sanity check
55
  assert len(dummy_placeholder_ids) == len(
56
- non_special_initializer_ids
57
- ), 'Length of "dummy_placeholder_ids" and "non_special_initializer_ids" must match.'
 
 
 
58
 
59
- return non_special_initializer_ids, dummy_placeholder_ids, dummy_placeholder_tokens
 
 
 
1
+ from typing import List, Tuple
2
 
3
  import torch
4
  import torch.nn as nn
5
 
6
 
7
def add_new_tokens_to_tokenizer(
    concept_str: str,
    initializer_strs: str,
    tokenizer: nn.Module,
) -> Tuple[torch.Tensor, List[int], str]:
    """Add dummy placeholder tokens for a concept to ``tokenizer``.

    Given a single concept token and the string of its initializer tokens,
    registers one dummy token (``<concept_str>_n``) per initializer id and
    returns everything needed to wire up the learned embeddings.

    Args:
        concept_str: The concept token; must not already exist in the
            tokenizer.
        initializer_strs: Space-separated initializer tokens that seed the
            concept's embeddings.
        tokenizer: A Hugging-Face-style tokenizer (callable, with
            ``add_tokens`` / ``convert_tokens_to_ids`` / ``unk_token_id``).

    Returns:
        A tuple of (initializer id tensor, dummy placeholder id list,
        space-joined dummy placeholder string).

    Raises:
        ValueError: If the concept or any generated dummy token already
            exists in the tokenizer, or if the id counts disagree.
    """
    # Validate with explicit raises, not `assert`: asserts vanish under -O.
    if token_exists_in_tokenizer(concept_str, tokenizer):
        raise ValueError(f"concept_str {concept_str} already exists in tokenizer.")

    # [0] drops the batch dimension of the returned (1, seq_len) id tensor.
    initializer_ids = tokenizer(
        initializer_strs,
        return_tensors="pt",
        add_special_tokens=False,
    ).input_ids[0]

    # Add a dummy placeholder token for every token in the initializer.
    dummy_placeholder_str_list = [
        f"<{concept_str}>_{n}" for n in range(len(initializer_ids))
    ]
    # Sanity check: none of the dummies may collide with existing vocab.
    for dummy in dummy_placeholder_str_list:
        if token_exists_in_tokenizer(dummy, tokenizer):
            raise ValueError(f"dummy {dummy} already exists in tokenizer.")

    dummy_placeholder_strs = " ".join(dummy_placeholder_str_list)

    tokenizer.add_tokens(dummy_placeholder_str_list)
    dummy_placeholder_ids = tokenizer.convert_tokens_to_ids(dummy_placeholder_str_list)
    # Sanity check that the dummies correspond to the correct number of ids.
    if len(dummy_placeholder_ids) != len(initializer_ids):
        raise ValueError(
            'Length of "dummy_placeholder_ids" and "initializer_ids" must match.'
        )

    return initializer_ids, dummy_placeholder_ids, dummy_placeholder_strs


def token_exists_in_tokenizer(token: str, tokenizer: nn.Module) -> bool:
    """Return True if ``token`` maps to a real (non-unk) id in ``tokenizer``."""
    exists = tokenizer.convert_tokens_to_ids([token]) != [tokenizer.unk_token_id]
    return exists