| import torch |
| from transformers import AutoConfig |
|
|
|
|
def extend_list(data_list, n, min_n):
    """Pad ``data_list`` in place to at least ``n`` items by cycling its contents.

    The list is extended with copies of its own leading elements until it
    holds ``n`` items. A list that is already ``n`` items or longer is
    returned untouched (it is never truncated).

    Args:
        data_list: List to pad. It is mutated in place.
        n: Target minimum length.
        min_n: When this is ``0`` a fresh empty list is returned and
            ``data_list`` is left unmodified.

    Returns:
        A new empty list when ``min_n == 0``; otherwise ``data_list`` itself.
    """
    if min_n == 0:
        return []
    if len(data_list) == 0:
        # Nothing to repeat: extending by a slice of an empty list adds no
        # items, so the padding loop below would never terminate.
        return data_list
    while len(data_list) < n:
        # Copy only as many leading items as are still missing.
        data_list.extend(data_list[: n - len(data_list)])
    return data_list
|
|
|
|
def find_prefix(input_ids, prefix):
    """Locate the first occurrence of ``prefix`` inside every row of ``input_ids``.

    input_ids: [B, N1], no start token
    prefix: [N2, ], no start token

    Returns:
        int64 tensor of shape [B] holding, for each row, the start index of
        the first window equal to ``prefix``.

    Raises:
        AssertionError: if at least one row does not contain ``prefix``.
    """
    # Compare on the device of ``input_ids`` so a prefix built elsewhere
    # (e.g. on CPU) still works.
    prefix = prefix.to(input_ids.device)
    window = prefix.shape[0]

    # [B, N1 - N2 + 1, N2]: every contiguous window of length N2 per row.
    windows = input_ids.unfold(1, window, 1)

    # [B, N1 - N2 + 1]: True where the whole window equals the prefix.
    is_match = (windows == prefix).all(dim=2)

    # argmax over 0/1 values gives the first matching position, but it also
    # gives 0 for rows with no match at all, so those rows are replaced with
    # a -1 sentinel below.
    first_match = is_match.type(torch.int64).argmax(dim=1)

    # Build the sentinel from ``first_match`` so it shares its device and
    # dtype, rather than creating a separate CPU tensor.
    indices = torch.where(
        is_match.any(dim=1),
        first_match,
        torch.full_like(first_match, -1),
    )
    assert (indices >= 0).all(), "Some inputs do not contain prefix"
    return indices
|
|
|
|
def auto_upgrade(config):
    """Interactively upgrade an old-format checkpoint config in place.

    Loads the configuration at ``config`` with ``AutoConfig.from_pretrained``.
    Nothing happens unless the string ``config`` contains ``"mplug_owl2"``
    while the loaded ``model_type`` does not. In that case the user is asked
    on stdin to confirm; on confirmation the config's ``model_type`` and
    single ``architectures`` entry are rewritten and saved back to ``config``.

    Args:
        config: Value passed to ``AutoConfig.from_pretrained`` and, on
            upgrade, to ``cfg.save_pretrained``.

    Raises:
        AssertionError: if the legacy ``model_type`` check fails, or if the
            config does not list exactly one architecture.
        SystemExit: with status 1 when the user declines the upgrade.
    """
    cfg = AutoConfig.from_pretrained(config)
    if "mplug_owl2" not in config or "mplug_owl2" in cfg.model_type:
        # Not a legacy checkpoint that needs upgrading.
        return

    # NOTE(review): this assertion contradicts the guard above (model_type
    # does not contain "mplug_owl2" here, so it cannot equal it), which makes
    # the upgrade path below unreachable. The expected legacy model_type is
    # not visible in this file -- confirm the intended value before changing.
    assert cfg.model_type == "mplug_owl2"
    print(
        "You are using newer LLaVA code base, while the checkpoint of v0 is from older code base."
    )
    print(
        "You must upgrade the checkpoint to the new code base (this can be done automatically)."
    )
    confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
    if confirm.lower() not in ["y", "yes"]:
        print("Checkpoint upgrade aborted.")
        # Raise SystemExit directly: the ``exit`` builtin is injected by the
        # ``site`` module and is not guaranteed to exist in every runtime.
        raise SystemExit(1)

    print("Upgrading checkpoint...")
    assert len(cfg.architectures) == 1
    # model_type is set on the config class (not the instance), as before.
    setattr(cfg.__class__, "model_type", "mplug_owl2")
    cfg.architectures[0] = "LlavaLlamaForCausalLM"
    cfg.save_pretrained(config)
    print("Checkpoint upgraded.")
|
|