import torch
from torch import nn
from transformers import PreTrainedModel, PretrainedConfig
from safetensors.torch import load_file

# CLIP
from .modeling_clipPT import CLIPVisionTransformer
from transformers import CLIPImageProcessor

# Qwen
from transformers import AutoTokenizer
from .modeling_qwen2 import Qwen2Model

# Timer
from .modeling_timer import TimerForPrediction

class MulTiCastTimerConfig(PretrainedConfig):
    def __init__(
        self,
        forecasting_length=None,
        vision_model_name=None,
        text_model_name=None,
        vision_model_prompt_len=None,
        text_model_prompt_len=None,
        timer_prompt_len=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.forecasting_length = forecasting_length
        self.vision_model_name = vision_model_name
        self.text_model_name = text_model_name
        
        self.vision_model_prompt_len = vision_model_prompt_len if vision_model_prompt_len is not None else 10
        self.text_model_prompt_len = text_model_prompt_len if text_model_prompt_len is not None else 4
        self.timer_prompt_len = timer_prompt_len if timer_prompt_len is not None else 4

class MulTiCastTimerModel(PreTrainedModel):
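    """Timer forecaster augmented with optional CLIP image and Qwen2 text prompts.

    The backbones are frozen; only the injected prompt parameters and the
    vision/text interaction layers receive gradients.
    """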
    
    config_class = MulTiCastTimerConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Vision model: load pretrained CLIP weights into the prompt-tuned variant
        if config.vision_model_name == 'CLIP':
            from transformers import AutoModel
            vision_model = AutoModel.from_pretrained("openai/clip-vit-base-patch32").vision_model
            state_dict = {k: v.to(torch.bfloat16) for k, v in vision_model.state_dict().items()}
            self.vision_model = CLIPVisionTransformer(vision_model.config, config.vision_model_prompt_len)
            self.vision_model.load_state_dict(state_dict, strict=False)  # prompt parameters are not in the checkpoint
            self.processor = CLIPImageProcessor()
            for name, param in self.vision_model.named_parameters():  # freeze everything except the prompts
                param.requires_grad = "encoder.prompts" in name
        
        # Text model: load pretrained Qwen2 weights into the prompt-tuned variant
        if config.text_model_name == 'Qwen':
            from transformers import AutoModelForCausalLM
            self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
            text_model = AutoModelForCausalLM.from_pretrained(
                "Qwen/Qwen2-1.5B-Instruct",
                torch_dtype=torch.bfloat16,
                device_map="cpu",
                attn_implementation="sdpa"
            ).model
            self.text_model = Qwen2Model(text_model.config, config.text_model_prompt_len)
            self.text_model.load_state_dict(text_model.state_dict(), strict=False)  # prompt parameters are not in the checkpoint
            for name, param in self.text_model.named_parameters():  # freeze everything except the prompts
                param.requires_grad = "prompts" in name
            
        # Timer backbone: always loaded; frozen except for its prompts
        from transformers import AutoModelForCausalLM
        timer = AutoModelForCausalLM.from_pretrained('thuml/timer-base-84m', trust_remote_code=True)
        self.timer = TimerForPrediction(timer.config, config.timer_prompt_len)
        self.timer.load_state_dict(timer.state_dict(), strict=False)  # prompt parameters are not in the checkpoint
        for name, param in self.timer.named_parameters():  # freeze everything except the prompts
            param.requires_grad = "model.prompts" in name
        
        # Interaction layers: project vision/text embeddings into the Timer hidden space
        if config.vision_model_name is not None:
            self.vision_interaction_layer = nn.Linear(self.vision_model.config.hidden_size, self.timer.config.hidden_size)
        if config.text_model_name is not None:
            self.text_interaction_layer = nn.Linear(self.text_model.config.hidden_size, self.timer.config.hidden_size)
    
    def predict(self, input_ids = None, images = None, texts = None):
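        """Inference helper that also returns attention maps.

        `images` is a raw image accepted by CLIPImageProcessor and `texts` a list
        of strings (entries may be None); either modality can be omitted.
        """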
        vision_attentions = None

        # Use the vision branch only when a vision encoder is configured and an image is given
        if self.config.vision_model_name is None or images is None:
            vision_embedding = None
        else:
            pixel_values = self.processor.preprocess(images)['pixel_values'][0]
            pixel_values = torch.tensor(pixel_values).unsqueeze(0)  # add a batch dimension
            vision_output = self.vision_model(pixel_values, output_attentions=True)
            vision_attentions = vision_output.attentions
            vision_embedding = self.vision_interaction_layer(vision_output.pooler_output)

        # Use the text branch only when a text encoder is configured and any text is given
        if self.config.text_model_name is None or texts is None or all(x is None for x in texts):
            text_embedding = None
        else:
            tokenized_texts = self.tokenizer(texts, return_tensors="pt")
            text_output = self.text_model(**tokenized_texts)
            text_embedding = self.text_interaction_layer(text_output.last_hidden_state[:, 0, :])

        out = self.timer(input_ids=input_ids, vision_embedding=vision_embedding, text_embedding=text_embedding)

        return {
            "logits": out.logits,
            "vision_attentions": vision_attentions,
            "time_series_attentions": out.attentions
        }
    
    def forward(self, input_ids = None, images = None, texts = None, labels = None):
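        """Training/evaluation pass returning an MSE loss when `labels` is given.

        Unlike `predict`, `images` here is expected to already be a preprocessed
        pixel tensor (e.g. produced by CLIPImageProcessor in the data pipeline).
        """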
        if self.config.vision_model_name is None or images is None:
            vision_embedding = None
        else:
            vision_embedding = self.vision_model(images).pooler_output
            vision_embedding = self.vision_interaction_layer(vision_embedding)

        if self.config.text_model_name is None or texts is None or all(x is None for x in texts):
            text_embedding = None
        else:
            tokenized_texts = self.tokenizer(texts, return_tensors="pt")
            text_embedding = self.text_model(**tokenized_texts)
            text_embedding = text_embedding.last_hidden_state[:, 0, :]
            text_embedding = self.text_interaction_layer(text_embedding)

        out = self.timer(input_ids=input_ids, vision_embedding=vision_embedding, text_embedding=text_embedding)
        out = out["logits"]

        if labels is not None:
            if self.config.forecasting_length == out.shape[-1]:
                loss = torch.mean(torch.square(out - labels))  # MSE over the full horizon
            else:
                # The pretrained Timer forecasts 96 steps; truncate when the target horizon is
                # shorter. A forecasting_length greater than 96 will raise an error.
                loss = torch.mean(torch.square(out[:, :self.config.forecasting_length] - labels))
        else:
            loss = None

        return {
            "loss": loss,
            "logits": out
        }
    
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
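        """Rebuild the model from its config and load weights from model.safetensors.

        strict=False tolerates keys that are missing from or unexpected in the
        checkpoint (the frozen backbones are restored from their hubs in __init__).
        """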
        from transformers.utils import cached_file
        config = MulTiCastTimerConfig.from_pretrained(pretrained_model_name_or_path)
        model = MulTiCastTimerModel(config)
        resolved_file = cached_file(pretrained_model_name_or_path, "model.safetensors")
        state_dict = load_file(resolved_file)
        model.load_state_dict(state_dict, strict=False)

        return model
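

# Minimal usage sketch (illustrative only; the config values, image path, and
# tensor shapes below are assumptions, not part of this module):
#
#   from PIL import Image
#   import torch
#
#   config = MulTiCastTimerConfig(
#       forecasting_length=96,
#       vision_model_name="CLIP",
#       text_model_name="Qwen",
#   )
#   model = MulTiCastTimerModel(config)
#
#   series = torch.randn(1, 672)              # (batch, lookback) time-series values
#   image = Image.open("context_plot.png")    # hypothetical image path
#   out = model.predict(input_ids=series, images=image, texts=["hourly electricity load"])
#   forecast = out["logits"]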