# fuse_clip_hub.py
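"""Hugging Face Hub integration for FuseLIP.

Combines the FuseCLIP architecture with huggingface_hub's
PyTorchModelHubMixin so a model gains `save_pretrained`,
`from_pretrained`, and `push_to_hub` without touching the architecture
code in `fuse_clip.fuse_clip_arch`.
"""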
import inspect
import json
import shutil
from pathlib import Path

import numpy as np
import torch
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download

from fuse_clip.fuse_clip_arch import FuseCLIP
from open_clip import get_input_dtype, SimpleTokenizer


class FuseLIP(FuseCLIP, PyTorchModelHubMixin):
    """FuseLIP with save_pretrained / from_pretrained / push_to_hub."""

    # ---------- save ----------
    def _save_pretrained(self, save_directory: Path, **kwargs):
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        torch.save(self.state_dict(), save_directory / "pytorch_model.bin")
        (save_directory / "config.json").write_text(
            json.dumps(self.get_config(), indent=2)
        )
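
        # Note: `pytorch_model.bin` and `config.json` are the two files
        # `_from_pretrained` below expects to find side by side.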
        # Optionally bundle the TiTok VQ-VAE weights with the checkpoint so the
        # image tokenizer can be loaded offline too; left disabled since the
        # saved config already records the tokenizer path.
        # shutil.copy(
        #     self.image_tokenizer.tokenizer_path,
        #     save_directory / "titok_image_tokenizer.pt"
        # )

        # Ship this source file with the checkpoint so the class definition
        # travels alongside the weights.
        source_path = Path(inspect.getfile(FuseLIP))  # absolute path of this file
        shutil.copy(source_path, save_directory / "fuse_clip_hub.py")

    # ---------- load ----------
    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision=None,
        cache_dir=None,
        force_download=False,
        proxies=None,
        resume_download=None,  # accepted for mixin compatibility, unused
        local_files_only=False,
        token=None,
        **model_kwargs,
    ):
        # ModelHubMixin passes `model_id` and the hub kwargs above by keyword;
        # capturing them explicitly keeps them out of the FuseCLIP constructor.
        # `model_id` may be a local checkpoint directory or a Hub repo id.
        if Path(model_id).is_dir():
            config_file = Path(model_id) / "config.json"
            weights_file = Path(model_id) / "pytorch_model.bin"
        else:
            hub_kwargs = dict(
                repo_id=model_id,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
            )
            config_file = Path(hf_hub_download(filename="config.json", **hub_kwargs))
            weights_file = Path(hf_hub_download(filename="pytorch_model.bin", **hub_kwargs))

        cfg = json.loads(config_file.read_text())

        # Rebuild the text tokenizer as configured at training time; the pad id
        # and the optional [MASK] token below must match the checkpoint's vocab.
        tokenizer = SimpleTokenizer(context_length=cfg["context_length"])
        tokenizer.pad_token_id = 0

        if cfg["mlm_probability"] > 0:
            MASK_TOKEN = "[MASK]"
            if MASK_TOKEN not in tokenizer.encoder:
                # Append [MASK] at the end of the vocabulary with a fresh id
                mask_token_id = max(tokenizer.encoder.values()) + 1

                # Add to tokenizer's vocabulary
                tokenizer.encoder[MASK_TOKEN] = mask_token_id
                tokenizer.decoder[mask_token_id] = MASK_TOKEN

                tokenizer.all_special_ids.append(mask_token_id)
                tokenizer.mask_token = mask_token_id
                tokenizer.vocab_size += 1

                print(f"Added `[MASK]` token with ID {mask_token_id}")
            else:
                mask_token_id = tokenizer.encoder[MASK_TOKEN]
                print(f"`[MASK]` token already exists with ID {mask_token_id}")

        # Translate the stored config into FuseCLIP constructor kwargs.
        cfg["image_tokenizer_path"] = cfg["image_tokenizer"]
        cfg["init_logit_scale"] = np.log(10)  # SigLIP-style logit scale/bias init
        cfg["init_logit_bias"] = -10
        cfg["input_dtype"] = get_input_dtype("fp32")
        del cfg["text_config"]
        del cfg["image_tokenizer"]
        del cfg["context_length"]

        # Caller kwargs (e.g. device / dtype) override the stored config rather
        # than colliding with it.
        cfg.update(model_kwargs)
        model = cls(**cfg)
        state = torch.load(weights_file, map_location="cpu")
        model.load_state_dict(state, strict=True)
        return model
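

# ---------------------------------------------------------------------------
# Usage sketch. This is a minimal, illustrative example, not part of the
# class: the checkpoint directory and repo id below are placeholders you
# must supply.
if __name__ == "__main__":
    import sys

    # Local checkpoint directory or Hub repo id.
    src = sys.argv[1] if len(sys.argv) > 1 else "checkpoints/fuselip"
    model = FuseLIP.from_pretrained(src)

    # Round-trip: re-save and reload to verify the files written above.
    model.save_pretrained("checkpoints/fuselip-resaved")
    reloaded = FuseLIP.from_pretrained("checkpoints/fuselip-resaved")
    n_params = sum(p.numel() for p in reloaded.parameters())
    print(f"reloaded {type(reloaded).__name__} with {n_params:,} parameters")

    # Publishing is one call (repo id is a placeholder):
    # model.push_to_hub("your-username/fuselip")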