jaeikkim committed
Commit e80840a · 1 Parent(s): db39f43

Cleanup binaries before space push
MMaDA/inference/common.py CHANGED
@@ -57,14 +57,6 @@ def build_uni_prompting(cfg) -> Tuple[UniversalPrompting, AutoTokenizer]:
         cond_dropout_prob=cfg.training.cond_dropout_prob,
         use_reserved_token=True,
     )
-    # Safety: if newer task tokens are missing (e.g., <|ti2ti|>, <|t2ti|>), inject them.
-    for tok in ("<|ti2ti|>", "<|t2ti|>"):
-        if tok not in uni_prompting.sptids_dict:
-            token_id = tokenizer.convert_tokens_to_ids(tok)
-            if token_id is None or token_id == tokenizer.unk_token_id:
-                tokenizer.add_special_tokens({"additional_special_tokens": [tok]})
-                token_id = tokenizer.convert_tokens_to_ids(tok)
-            uni_prompting.sptids_dict[tok] = torch.tensor([token_id])
     return uni_prompting, tokenizer
 
 
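Note: with the injection fallback gone, build_uni_prompting now assumes the loaded tokenizer already defines the newer task tokens. A minimal caller-side guard, sketched here as a hypothetical helper (assert_task_tokens is not part of the repo; sptids_dict and the token strings are taken from the diff above):

# Hypothetical guard, not part of this commit: fail fast if the checkpoint's
# tokenizer lacks the task tokens the removed fallback used to inject.
def assert_task_tokens(uni_prompting, tokens=("<|ti2ti|>", "<|t2ti|>")):
    missing = [tok for tok in tokens if tok not in uni_prompting.sptids_dict]
    if missing:
        raise ValueError(f"Missing task tokens in tokenizer/config: {missing}")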
MMaDA/inference/gradio_multimodal_demo_inst.py CHANGED
@@ -1259,79 +1259,79 @@ class OmadaDemo:
             return None, "", f"Failed to encode source image: {exc}"
 
         text_tokens = max(4, min(int(text_tokens), self.max_text_len))
-        prompt_ids = self.uni_prompting.text_tokenizer(instruction_clean)['input_ids']
-        if isinstance(prompt_ids, list) and prompt_ids and isinstance(prompt_ids[0], list):
-            prompt_ids = prompt_ids[0]
-        if len(prompt_ids) == 0 or prompt_ids[0] != self.uni_prompting.text_tokenizer.bos_token_id:
-            prompt_ids = [self.uni_prompting.text_tokenizer.bos_token_id] + prompt_ids
-        prompt_ids = prompt_ids + [self.uni_prompting.text_tokenizer.eos_token_id]
-        prompt_tensor = torch.tensor(prompt_ids, device=self.device, dtype=torch.long)
-
-        def _get_token(key: str):
-            tok = self.uni_prompting.sptids_dict.get(key)
-            if tok is None or tok.numel() == 0:
-                return None
-            return int(tok[0].item())
-
-        ti2ti_id = _get_token('<|ti2ti|>')
-        soi_id = _get_token('<|soi|>')
-        eoi_id = _get_token('<|eoi|>')
-        if ti2ti_id is None or soi_id is None or eoi_id is None:
-            return None, "", "TI2TI special tokens are missing in the tokenizer/config."
-        pad_raw = getattr(self.uni_prompting, "pad_id", 0)
-        pad_id = int(pad_raw if pad_raw is not None else 0)
-
-        img_placeholder = torch.full(
-            (self.image_seq_len,),
+
+        # Build prompts using the same helper as training eval (ti2ti_prompt)
+        placeholder_img = torch.full(
+            (1, self.image_seq_len),
             self.mask_token_id,
             dtype=torch.long,
             device=self.device,
         )
-        text_placeholder = torch.full(
-            (text_tokens,),
-            self.mask_token_id,
-            dtype=torch.long,
-            device=self.device,
+        labels_img_placeholder = torch.full_like(placeholder_img, int(self.uni_prompting.ignore_id))
+        text_mask_bool = torch.ones(text_tokens, device=self.device, dtype=torch.bool)
+
+        input_ids, attention_mask, _ = self.uni_prompting.ti2ti_prompt(
+            prompts=[instruction_clean],
+            source_tokens=src_tokens,
+            masked_target_tokens=placeholder_img,
+            labels_img=labels_img_placeholder,
+            target_texts=[""],
+            target_mask_bools=[text_mask_bool],
+            task_token="<|ti2ti|>",
+        )
+        uncond_ids, uncond_attn, _ = self.uni_prompting.ti2ti_prompt(
+            prompts=[""],
+            source_tokens=src_tokens,
+            masked_target_tokens=placeholder_img,
+            labels_img=labels_img_placeholder,
+            target_texts=[""],
+            target_mask_bools=[text_mask_bool],
+            task_token="<|ti2ti|>",
         )
 
-        src_flat = src_tokens.view(-1)
-        prompt_len = prompt_tensor.numel()
-        img_len = img_placeholder.numel()
-        text_len = text_placeholder.numel()
-
-        prompt_start = 2 + src_flat.numel() + 1
-        prompt_end = prompt_start + prompt_len
-        img_start = prompt_end + 1
-        img_end = img_start + img_len
-        text_start = img_end + 1
-        text_end = text_start + text_len
-
-        seq_parts = [
-            torch.tensor([ti2ti_id, soi_id], device=self.device, dtype=torch.long),
-            src_flat,
-            torch.tensor([eoi_id], device=self.device, dtype=torch.long),
-            prompt_tensor,
-            torch.tensor([soi_id], device=self.device, dtype=torch.long),
-            img_placeholder,
-            torch.tensor([eoi_id], device=self.device, dtype=torch.long),
-            text_placeholder,
-        ]
-        seq = torch.cat(seq_parts, dim=0).unsqueeze(0)
-        attn = torch.ones_like(seq, dtype=torch.long, device=self.device)
-
-        uncond_seq = seq.clone()
-        uncond_attn = attn.clone()
-        uncond_seq[:, prompt_start:prompt_end] = pad_id
-        uncond_attn[:, prompt_start:prompt_end] = 0
+        input_ids = input_ids.to(self.device)
+        attention_mask = attention_mask.to(self.device) if attention_mask is not None else None
+        uncond_ids = uncond_ids.to(self.device)
+        uncond_attn = uncond_attn.to(self.device) if uncond_attn is not None else None
+
+        # Locate spans before generation so we can force attention over desired text length
+        seq_example = input_ids[0]
+        soi_id = int(self.uni_prompting.sptids_dict['<|soi|>'][0].item())
+        eoi_id = int(self.uni_prompting.sptids_dict['<|eoi|>'][0].item())
+        pad_id = int(getattr(self.uni_prompting, "pad_id", 0))
+        text_block_len = text_tokens
+
+        soi_positions = (seq_example == soi_id).nonzero(as_tuple=True)[0]
+        eoi_positions = (seq_example == eoi_id).nonzero(as_tuple=True)[0]
+        img_start = img_end = text_start = None
+        if soi_positions.numel() >= 2:
+            tgt_soi = int(soi_positions[1].item())
+            eoi_after = [int(e.item()) for e in eoi_positions if int(e.item()) > tgt_soi]
+            if eoi_after:
+                tgt_eoi = eoi_after[0]
+                img_start = tgt_soi + 1
+                img_end = min(tgt_eoi, input_ids.shape[1])
+                text_start = tgt_eoi + 1
+        if img_start is None:
+            non_pad = (seq_example != pad_id).nonzero(as_tuple=True)
+            pad_offset = int(non_pad[0][0].item()) if len(non_pad) > 0 and non_pad[0].numel() > 0 else 0
+            img_start = pad_offset + 1 + 1 + self.image_seq_len + 1 + self.uni_prompting.max_text_len + 1
+            img_end = img_start + self.image_seq_len
+            text_start = img_end + 1
+        text_end = min(text_start + text_block_len, input_ids.shape[1])
+        if attention_mask is not None:
+            attention_mask[:, text_start:text_end] = 1
+        if uncond_attn is not None:
+            uncond_attn[:, text_start:text_end] = 1
 
         with torch.no_grad():
             filled_tokens, _ = self.model.ti2ti_generate(
-                input_ids=seq.to(self.device),
-                uncond_input_ids=uncond_seq.to(self.device),
-                attention_mask=attn.to(self.device),
-                uncond_attention_mask=uncond_attn.to(self.device),
+                input_ids=input_ids,
+                uncond_input_ids=uncond_ids,
+                attention_mask=attention_mask,
+                uncond_attention_mask=uncond_attn,
                 temperature=float(temperature),
-                timesteps=int(timesteps_image),
+                timesteps=int(max(timesteps_image, timesteps_text)),
                 timesteps_text=int(timesteps_text),
                 timesteps_image=int(timesteps_image),
                 guidance_scale=float(guidance_scale),
@@ -1346,6 +1346,7 @@ class OmadaDemo:
         if filled_tokens is None:
             return None, "", "TI2TI generation failed."
 
+        # Locate spans like evaluate_ti2ti (target image/text blocks)
         filled_tokens = torch.clamp(
             filled_tokens,
             min=0,
@@ -1358,7 +1359,7 @@ class OmadaDemo:
         except Exception as exc:
             return None, "", f"Failed to decode generated image: {exc}"
 
-        text_slice = slice(text_start, min(text_end, filled_tokens.shape[1]))
+        text_slice = slice(text_start, text_end)
         text_block = filled_tokens[:, text_slice]
         text_vocab = self.text_vocab_size
         mask_id = int(self.mask_token_id)
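For reference, the target-span lookup added above can be exercised on its own. A minimal sketch, assuming only that <|soi|>/<|eoi|> delimit image blocks and that the target image is the second such block (the token ids and toy layout below are invented for illustration; real ids come from uni_prompting.sptids_dict):

import torch

def locate_ti2ti_spans(seq, soi_id, eoi_id):
    # Target image block = the block opened by the second <|soi|>;
    # the generated text begins right after the matching <|eoi|>.
    soi_pos = (seq == soi_id).nonzero(as_tuple=True)[0]
    eoi_pos = (seq == eoi_id).nonzero(as_tuple=True)[0]
    if soi_pos.numel() < 2:
        return None
    tgt_soi = int(soi_pos[1].item())
    after = eoi_pos[eoi_pos > tgt_soi]
    if after.numel() == 0:
        return None
    tgt_eoi = int(after[0].item())
    return tgt_soi + 1, tgt_eoi, tgt_eoi + 1  # img_start, img_end, text_start

# Toy layout: [task, soi, src, src, eoi, prompt, prompt, soi, img, img, eoi, txt, txt]
seq = torch.tensor([9, 1, 5, 5, 2, 7, 7, 1, 0, 0, 2, 8, 8])
print(locate_ti2ti_spans(seq, soi_id=1, eoi_id=2))  # -> (8, 10, 11)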
 
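The conditional/unconditional pair built above feeds classifier-free guidance inside model.ti2ti_generate, whose internals this diff does not show. Presumably it combines the two logit streams in the standard way; a sketch of that combination (function name and signature are illustrative only):

# Standard classifier-free guidance mix; ti2ti_generate's actual
# implementation is not shown in this diff.
def cfg_logits(cond_logits, uncond_logits, guidance_scale):
    return uncond_logits + guidance_scale * (cond_logits - uncond_logits)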