import os
import itertools
import json
from pathlib import Path
from typing import Dict, Iterable, List, Tuple

import torch
import wandb
from omegaconf import OmegaConf
from transformers import AutoTokenizer

from models import MAGVITv2, OMadaModelLM
from models.modeling_emova_speech_tokenizer import EMOVASpeechTokenizer
from training.prompting_utils import UniversalPrompting

def load_train_config(path: str):
    """Load the training config used to produce the checkpoints."""
    cfg = OmegaConf.load(path)
    return cfg


def get_vq_model_image(cfg, device):
    """Build the image VQ tokenizer, preferring a local state dict if one is configured."""
    vq_cfg = cfg.model.vq_model_image
    if getattr(vq_cfg, "pretrained_model_path", None):
        # Local checkpoint: instantiate the architecture and load the exported weights.
        model = MAGVITv2().to(device)
        state_dict = torch.load(vq_cfg.pretrained_model_path)["model"]
        model.load_state_dict(state_dict)
        return model.eval()
    else:
        # Otherwise fall back to the published MAGVITv2 checkpoint named in the config.
        return MAGVITv2.from_pretrained(vq_cfg.vq_model_name).to(device).eval()


def get_vq_model_audio(cfg, device):
    """Build the audio VQ tokenizer (always EMOVA for now)."""
    vq_cfg = cfg.model.vq_model_audio
    model = EMOVASpeechTokenizer.from_pretrained(vq_cfg.vq_model_name)
    model = model.to(device)
    model.eval()
    return model

def build_uni_prompting(cfg) -> Tuple[UniversalPrompting, AutoTokenizer]:
    """Create the UniversalPrompting helper and its tokenizer from the training config."""
    tokenizer = AutoTokenizer.from_pretrained(cfg.model.omada.tokenizer_path, padding_side="left")
    uni_prompting = UniversalPrompting(
        tokenizer,
        max_text_len=cfg.dataset.preprocessing.max_seq_length,
        max_audio_len=cfg.dataset.preprocessing.max_aud_length,
        special_tokens=(
            "<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>",
            "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>",
            "<|i2i|>", "<|v2s|>", "<|s2s|>",
            "<|v2t|>", "<|s2t|>", "<|t2s|>", "<|soa|>", "<|eoa|>",
            "<|ti2ti|>", "<|t2ti|>",
        ),
        ignore_id=-100,
        cond_dropout_prob=cfg.training.cond_dropout_prob,
        use_reserved_token=True,
    )
    return uni_prompting, tokenizer

def load_omada_from_checkpoint(ckpt_unwrapped_dir: str, device, torch_dtype=torch.bfloat16) -> OMadaModelLM:
    """Load OMada model weights from an `unwrapped_model` directory.

    The helper used to rely on a hard-coded config path, which broke when
    evaluating checkpoints from other training steps. We now detect the
    config.json co-located with the weights, so any checkpoint exported by
    the trainer can be used directly.
    """
    ckpt_path = Path(ckpt_unwrapped_dir)
    if not ckpt_path.is_dir():
        raise FileNotFoundError(f"Expected an 'unwrapped_model' directory, got {ckpt_unwrapped_dir}")
    config_path = ckpt_path / "config.json"
    config_arg = str(config_path) if config_path.exists() else None
    model = OMadaModelLM.from_pretrained(
        ckpt_unwrapped_dir,
        torch_dtype=torch_dtype,
        config=config_arg,
        trust_remote_code=True,
    ).to(device)
    model.eval()
    return model

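# Usage sketch (the checkpoint directory below is hypothetical; point it at any
# `unwrapped_model` directory exported by the trainer):
#
#     model = load_omada_from_checkpoint(
#         "output/omada_exp/checkpoint-10000/unwrapped_model",
#         device="cuda" if torch.cuda.is_available() else "cpu",
#     )
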
def list_checkpoints(ckpt_root: str) -> List[str]:
    """Return a sorted list of checkpoint 'unwrapped_model' dirs under a training output dir or a direct ckpt dir.

    Accepts either:
      - a path that already ends with 'unwrapped_model'
      - a path to 'checkpoint-XXXX' (we append 'unwrapped_model')
      - a path to the experiment output dir containing multiple 'checkpoint-*' directories
    """
    p = Path(ckpt_root)
    if p.name == "unwrapped_model" and p.is_dir():
        return [str(p)]
    if p.name.startswith("checkpoint-") and p.is_dir():
        inner = p / "unwrapped_model"
        return [str(inner)] if inner.is_dir() else []
    # Otherwise, collect the 'unwrapped_model' dirs of all child checkpoints.
    outs = []
    for child in p.iterdir():
        if child.is_dir() and child.name.startswith("checkpoint-"):
            inner = child / "unwrapped_model"
            if inner.is_dir():
                outs.append(str(inner))

    # Sort by numeric training step when the 'checkpoint-XXXX' name allows it.
    def step_key(s: str):
        try:
            return int(Path(s).parent.name.split("-")[-1])
        except Exception:
            return -1

    outs.sort(key=step_key)
    return outs

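# Example (hypothetical layout): given an experiment directory
#
#     output/omada_exp/
#         checkpoint-5000/unwrapped_model/
#         checkpoint-10000/unwrapped_model/
#
# each of the following returns the matching 'unwrapped_model' paths, sorted by step:
#
#     list_checkpoints("output/omada_exp")                                  # both checkpoints
#     list_checkpoints("output/omada_exp/checkpoint-5000")                  # just step 5000
#     list_checkpoints("output/omada_exp/checkpoint-5000/unwrapped_model")  # just step 5000
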
def grid_dict(product_space: Dict[str, Iterable]) -> List[Dict]:
    """Expand a dict of lists to a list of dict combinations.

    Example: {a: [1, 2], b: ["x"]} -> [{a: 1, b: "x"}, {a: 2, b: "x"}]
    """
    keys = list(product_space.keys())
    # Scalars are treated as single-element lists so they can be mixed with swept lists.
    values = [list(v if isinstance(v, (list, tuple)) else [v]) for v in product_space.values()]
    combos = []
    for vals in itertools.product(*values):
        combos.append(dict(zip(keys, vals)))
    return combos

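# Usage sketch (the swept hyperparameter names are hypothetical). Scalars are
# broadcast alongside swept lists:
#
#     grid_dict({"guidance_scale": [1.5, 3.0], "temperature": 1.0})
#     # -> [{"guidance_scale": 1.5, "temperature": 1.0},
#     #     {"guidance_scale": 3.0, "temperature": 1.0}]
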
def init_wandb(infer_cfg: Dict, task: str, ckpt_path: str, hparams: Dict):
    """Start a wandb run named after the checkpoint step and the hyperparameter combo."""
    wcfg = infer_cfg.get("wandb", {})
    project = wcfg.get("project", f"omada-inference-{task}")
    entity = wcfg.get("entity")
    group = wcfg.get("group", task)
    name_prefix = wcfg.get("name_prefix", task)
    # ckpt_path points at .../checkpoint-XXXX/unwrapped_model, so the parent name is the step.
    step_str = Path(ckpt_path).parent.name
    run_name = f"{name_prefix}-{step_str}-" + ",".join([f"{k}={v}" for k, v in hparams.items()])
    tags = wcfg.get("tags", [])
    wandb.init(project=project, entity=entity, group=group, name=run_name, tags=tags, config={
        "task": task,
        "checkpoint": ckpt_path,
        "hparams": hparams,
    })

def safe_log_table(name: str, columns: List[str], rows: List[List]):
    """Log a wandb table, swallowing errors so logging never aborts an inference run."""
    try:
        table = wandb.Table(columns=columns)
        for r in rows:
            table.add_data(*r)
        wandb.log({name: table})
    except Exception:
        pass

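# Minimal end-to-end sketch of how these helpers compose. The config path,
# checkpoint root, task name, and swept hyperparameters are assumptions for
# illustration, not fixed names from the training setup.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cfg = load_train_config("configs/omada_train.yaml")   # hypothetical config path
    uni_prompting, tokenizer = build_uni_prompting(cfg)
    infer_cfg = {"wandb": {"project": "omada-inference-demo"}}

    for ckpt in list_checkpoints("output/omada_exp"):      # hypothetical output dir
        model = load_omada_from_checkpoint(ckpt, device)
        for hparams in grid_dict({"temperature": [0.7, 1.0]}):  # hypothetical sweep
            init_wandb(infer_cfg, task="t2s", ckpt_path=ckpt, hparams=hparams)
            # ... run generation / evaluation for this checkpoint and hparam combo ...
            safe_log_table("samples", ["prompt", "output"], [["hello", "<generated>"]])
            wandb.finish()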