Cleanup binaries before space push
MMaDA/inference/common.py CHANGED

@@ -57,14 +57,6 @@ def build_uni_prompting(cfg) -> Tuple[UniversalPrompting, AutoTokenizer]:
         cond_dropout_prob=cfg.training.cond_dropout_prob,
         use_reserved_token=True,
     )
-    # Safety: if newer task tokens are missing (e.g., <|ti2ti|>, <|t2ti|>), inject them.
-    for tok in ("<|ti2ti|>", "<|t2ti|>"):
-        if tok not in uni_prompting.sptids_dict:
-            token_id = tokenizer.convert_tokens_to_ids(tok)
-            if token_id is None or token_id == tokenizer.unk_token_id:
-                tokenizer.add_special_tokens({"additional_special_tokens": [tok]})
-                token_id = tokenizer.convert_tokens_to_ids(tok)
-            uni_prompting.sptids_dict[tok] = torch.tensor([token_id])
     return uni_prompting, tokenizer
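Note: with the fallback removed, a tokenizer that lacks <|ti2ti|> or <|t2ti|> is no longer patched inside build_uni_prompting. If that behaviour is ever needed again, the same registration can live with the caller. A minimal sketch, assuming a Hugging Face tokenizer and the sptids_dict mapping used above (register_task_tokens is a hypothetical helper, not part of this commit):

import torch

def register_task_tokens(uni_prompting, tokenizer, tokens=("<|ti2ti|>", "<|t2ti|>")):
    # Mirror of the removed safety block: make sure each task token has an id
    # in the tokenizer and an entry in uni_prompting.sptids_dict.
    for tok in tokens:
        if tok in uni_prompting.sptids_dict:
            continue
        token_id = tokenizer.convert_tokens_to_ids(tok)
        if token_id is None or token_id == tokenizer.unk_token_id:
            tokenizer.add_special_tokens({"additional_special_tokens": [tok]})
            token_id = tokenizer.convert_tokens_to_ids(tok)
        uni_prompting.sptids_dict[tok] = torch.tensor([token_id])

If new tokens are actually added this way, the model's embedding table also has to be resized (for Hugging Face models, model.resize_token_embeddings(len(tokenizer))), which neither the removed block nor this sketch does.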
MMaDA/inference/gradio_multimodal_demo_inst.py CHANGED

@@ -1259,79 +1259,79 @@ class OmadaDemo:
             return None, "", f"Failed to encode source image: {exc}"

         text_tokens = max(4, min(int(text_tokens), self.max_text_len))
-        ...
-        prompt_ids = [self.uni_prompting.text_tokenizer.bos_token_id] + prompt_ids
-        prompt_ids = prompt_ids + [self.uni_prompting.text_tokenizer.eos_token_id]
-        prompt_tensor = torch.tensor(prompt_ids, device=self.device, dtype=torch.long)
-
-        def _get_token(key: str):
-            tok = self.uni_prompting.sptids_dict.get(key)
-            if tok is None or tok.numel() == 0:
-                return None
-            return int(tok[0].item())
-
-        ti2ti_id = _get_token('<|ti2ti|>')
-        soi_id = _get_token('<|soi|>')
-        eoi_id = _get_token('<|eoi|>')
-        if ti2ti_id is None or soi_id is None or eoi_id is None:
-            return None, "", "TI2TI special tokens are missing in the tokenizer/config."
-        pad_raw = getattr(self.uni_prompting, "pad_id", 0)
-        pad_id = int(pad_raw if pad_raw is not None else 0)
-
-        img_placeholder = torch.full(
-            (self.image_seq_len,),
+
+        # Build prompts using the same helper as training eval (ti2ti_prompt)
+        placeholder_img = torch.full(
+            (1, self.image_seq_len),
             self.mask_token_id,
             dtype=torch.long,
             device=self.device,
         )
-        ...
+        labels_img_placeholder = torch.full_like(placeholder_img, int(self.uni_prompting.ignore_id))
+        text_mask_bool = torch.ones(text_tokens, device=self.device, dtype=torch.bool)
+
+        input_ids, attention_mask, _ = self.uni_prompting.ti2ti_prompt(
+            prompts=[instruction_clean],
+            source_tokens=src_tokens,
+            masked_target_tokens=placeholder_img,
+            labels_img=labels_img_placeholder,
+            target_texts=[""],
+            target_mask_bools=[text_mask_bool],
+            task_token="<|ti2ti|>",
+        )
+        uncond_ids, uncond_attn, _ = self.uni_prompting.ti2ti_prompt(
+            prompts=[""],
+            source_tokens=src_tokens,
+            masked_target_tokens=placeholder_img,
+            labels_img=labels_img_placeholder,
+            target_texts=[""],
+            target_mask_bools=[text_mask_bool],
+            task_token="<|ti2ti|>",
         )

-        ...
+        input_ids = input_ids.to(self.device)
+        attention_mask = attention_mask.to(self.device) if attention_mask is not None else None
+        uncond_ids = uncond_ids.to(self.device)
+        uncond_attn = uncond_attn.to(self.device) if uncond_attn is not None else None
+
+        # Locate spans before generation so we can force attention over desired text length
+        seq_example = input_ids[0]
+        soi_id = int(self.uni_prompting.sptids_dict['<|soi|>'][0].item())
+        eoi_id = int(self.uni_prompting.sptids_dict['<|eoi|>'][0].item())
+        pad_id = int(getattr(self.uni_prompting, "pad_id", 0))
+        text_block_len = text_tokens
+
+        soi_positions = (seq_example == soi_id).nonzero(as_tuple=True)[0]
+        eoi_positions = (seq_example == eoi_id).nonzero(as_tuple=True)[0]
+        img_start = img_end = text_start = None
+        if soi_positions.numel() >= 2:
+            tgt_soi = int(soi_positions[1].item())
+            eoi_after = [int(e.item()) for e in eoi_positions if int(e.item()) > tgt_soi]
+            if eoi_after:
+                tgt_eoi = eoi_after[0]
+                img_start = tgt_soi + 1
+                img_end = min(tgt_eoi, input_ids.shape[1])
+                text_start = tgt_eoi + 1
+        if img_start is None:
+            non_pad = (seq_example != pad_id).nonzero(as_tuple=True)
+            pad_offset = int(non_pad[0][0].item()) if len(non_pad) > 0 and non_pad[0].numel() > 0 else 0
+            img_start = pad_offset + 1 + 1 + self.image_seq_len + 1 + self.uni_prompting.max_text_len + 1
+            img_end = img_start + self.image_seq_len
+            text_start = img_end + 1
+        text_end = min(text_start + text_block_len, input_ids.shape[1])
+        if attention_mask is not None:
+            attention_mask[:, text_start:text_end] = 1
+        if uncond_attn is not None:
+            uncond_attn[:, text_start:text_end] = 1

         with torch.no_grad():
             filled_tokens, _ = self.model.ti2ti_generate(
-                input_ids=
-                uncond_input_ids=
-                attention_mask=
-                uncond_attention_mask=uncond_attn
+                input_ids=input_ids,
+                uncond_input_ids=uncond_ids,
+                attention_mask=attention_mask,
+                uncond_attention_mask=uncond_attn,
                 temperature=float(temperature),
-                timesteps=int(timesteps_image),
+                timesteps=int(max(timesteps_image, timesteps_text)),
                 timesteps_text=int(timesteps_text),
                 timesteps_image=int(timesteps_image),
                 guidance_scale=float(guidance_scale),
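A note on the fallback offsets added above: img_start = pad_offset + 1 + 1 + self.image_seq_len + 1 + self.uni_prompting.max_text_len + 1 only makes sense for a fixed prompt layout, roughly [pad...][task token][<|soi|>][source image][<|eoi|>][instruction padded to max_text_len][<|soi|>][target image][<|eoi|>][target text]. That layout is inferred from the offsets, not stated anywhere in this diff. A small self-contained sketch of the same computation under that assumption:

# Sketch only: the layout is an assumption read off the fallback offsets above,
# not a documented contract of ti2ti_prompt.
def locate_ti2ti_spans(pad_offset, image_seq_len, max_text_len, text_len):
    # [pad...][task][<|soi|>][source image][<|eoi|>][instruction][<|soi|>] -> target image starts here
    img_start = pad_offset + 1 + 1 + image_seq_len + 1 + max_text_len + 1
    img_end = img_start + image_seq_len   # index of the closing <|eoi|> of the target image block
    text_start = img_end + 1              # target text follows that <|eoi|>
    text_end = text_start + text_len
    return img_start, img_end, text_start, text_end

The primary path in the diff locates the spans from the second <|soi|> and the <|eoi|> after it, so this arithmetic only matters when those markers are not found.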
@@ -1346,6 +1346,7 @@ class OmadaDemo:
         if filled_tokens is None:
            return None, "", "TI2TI generation failed."

+        # Locate spans like evaluate_ti2ti (target image/text blocks)
         filled_tokens = torch.clamp(
             filled_tokens,
             min=0,
@@ -1358,7 +1359,7 @@ class OmadaDemo:
         except Exception as exc:
             return None, "", f"Failed to decode generated image: {exc}"

-        text_slice = slice(text_start,
+        text_slice = slice(text_start, text_end)
         text_block = filled_tokens[:, text_slice]
         text_vocab = self.text_vocab_size
         mask_id = int(self.mask_token_id)
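The last hunk stops right where the generated text block is sliced out of filled_tokens; the rest of the text decoding is outside this commit. As a rough, purely illustrative continuation, assuming text ids live below self.text_vocab_size and the demo uses a Hugging Face tokenizer (decode_text_block is a hypothetical helper, not repository code):

import torch

def decode_text_block(text_block: torch.Tensor, tokenizer, text_vocab: int, mask_id: int) -> str:
    # text_block: (1, text_len) slice taken with filled_tokens[:, text_slice]
    ids = text_block[0].tolist()
    # keep only ids inside the text vocabulary, dropping any leftover mask tokens
    kept = [i for i in ids if i != mask_id and 0 <= i < text_vocab]
    return tokenizer.decode(kept, skip_special_tokens=True).strip()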