import torch
from torch import nn
from transformers import PreTrainedModel, PretrainedConfig
from safetensors.torch import load_file

# CLIP
from .modeling_clipPT import CLIPVisionTransformer
from transformers import CLIPImageProcessor
from transformers import AutoTokenizer

# Qwen
from .modeling_qwen2 import Qwen2Model

# Timer
from .modeling_timer import TimerForPrediction


class MulTiCastTimerConfig(PretrainedConfig):
    def __init__(
        self,
        forecasting_length=None,
        vision_model_name=None,
        text_model_name=None,
        vision_model_prompt_len=None,
        text_model_prompt_len=None,
        timer_prompt_len=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.forecasting_length = forecasting_length
        self.vision_model_name = vision_model_name
        self.text_model_name = text_model_name
        self.vision_model_prompt_len = vision_model_prompt_len if vision_model_prompt_len is not None else 10
        self.text_model_prompt_len = text_model_prompt_len if text_model_prompt_len is not None else 4
        self.timer_prompt_len = timer_prompt_len if timer_prompt_len is not None else 4


class MulTiCastTimerModel(PreTrainedModel):
    config_class = MulTiCastTimerConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Vision Model
        if config.vision_model_name is None:
            pass
        elif config.vision_model_name == 'CLIP':
            from transformers import AutoModel
            vision_model = AutoModel.from_pretrained("openai/clip-vit-base-patch32").vision_model
            state_dict = vision_model.state_dict()
            state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
            self.vision_model = CLIPVisionTransformer(vision_model.config, config.vision_model_prompt_len)
            self.vision_model.load_state_dict(state_dict, strict=False)
            self.processor = CLIPImageProcessor()
            for name, param in self.vision_model.named_parameters():
                # Freeze all layers other than the prompt parameters
                if "encoder.prompts" in name:
                    param.requires_grad = True
                else:
                    param.requires_grad = False
        else:
            pass

        # Text Model
        if config.text_model_name is None:
            pass
        elif config.text_model_name == 'Qwen':
            self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
            from transformers import AutoModelForCausalLM
            text_model = AutoModelForCausalLM.from_pretrained(
                "Qwen/Qwen2-1.5B-Instruct",
                torch_dtype=torch.bfloat16,
                device_map="cpu",
                attn_implementation="sdpa"
            ).model
            state_dict = text_model.state_dict()
            self.text_model = Qwen2Model(text_model.config, config.text_model_prompt_len)
            self.text_model.load_state_dict(state_dict, strict=False)
            for name, param in self.text_model.named_parameters():
                # Freeze all layers other than the prompt parameters
                if "prompts" in name:
                    param.requires_grad = True
                else:
                    param.requires_grad = False
        else:
            pass

        # Timer
        from transformers import AutoModelForCausalLM
        timer = AutoModelForCausalLM.from_pretrained('thuml/timer-base-84m', trust_remote_code=True)
        state_dict = timer.state_dict()
        self.timer = TimerForPrediction(timer.config, config.timer_prompt_len)
        self.timer.load_state_dict(state_dict, strict=False)
        for name, param in self.timer.named_parameters():
            # Freeze all layers other than the prompt parameters
            if "model.prompts" in name:
                param.requires_grad = True
            else:
                param.requires_grad = False

        # Vision Interaction Layer
        if config.vision_model_name is not None:
            self.vision_interaction_layer = nn.Linear(self.vision_model.config.hidden_size, self.timer.config.hidden_size)

        # Text Interaction Layer
        if config.text_model_name is not None:
            self.text_interaction_layer = nn.Linear(self.text_model.config.hidden_size, self.timer.config.hidden_size)

    def predict(self, input_ids=None, images=None, texts=None):
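        """Run inference with optional image and text conditioning.

        ``input_ids`` holds the raw time-series values passed to the Timer
        backbone, ``images`` is an optional image input handled by the CLIP
        image processor, and ``texts`` is an optional list of strings handled
        by the Qwen tokenizer. Returns a dict with the forecast ``logits``,
        ``vision_attentions`` (``None`` when the vision branch is skipped),
        and ``time_series_attentions``.
        """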
        vision_attentions = None
        # Skip the vision branch when no vision encoder is configured or no image is given.
        if self.config.vision_model_name is None or images is None:
            vision_embedding = None
        else:
            # Preprocess the raw image(s) into a (1, C, H, W) pixel-value tensor.
            images = self.processor.preprocess(images)['pixel_values'][0]
            images = torch.tensor(images)
            images = images.unsqueeze(0)
            vision_output = self.vision_model(images, output_attentions=True)
            vision_attentions = vision_output.attentions
            vision_embedding = vision_output.pooler_output
            vision_embedding = self.vision_interaction_layer(vision_embedding)

        # Skip the text branch when no text encoder is configured or no text is given.
        if self.config.text_model_name is None or texts is None or all(x is None for x in texts):
            text_embedding = None
        else:
            tokenized_texts = self.tokenizer(texts, return_tensors="pt")
            text_embedding = self.text_model(**tokenized_texts)
            # Use the first token's hidden state as the text representation.
            text_embedding = text_embedding.last_hidden_state[:, 0, :]
            text_embedding = self.text_interaction_layer(text_embedding)

        out = self.timer(input_ids=input_ids, vision_embedding=vision_embedding, text_embedding=text_embedding)
        return {
            "logits": out.logits,
            "vision_attentions": vision_attentions,
            "time_series_attentions": out.attentions
        }

    def forward(self, input_ids=None, images=None, texts=None, labels=None):
        # Skip the vision branch when no vision encoder is configured or no images are given.
        if self.config.vision_model_name is None or images is None:
            vision_embedding = None
        else:
            vision_embedding = self.vision_model(images)
            vision_embedding = vision_embedding.pooler_output
            vision_embedding = self.vision_interaction_layer(vision_embedding)

        # Skip the text branch when no text encoder is configured or no texts are given.
        if self.config.text_model_name is None or texts is None or all(x is None for x in texts):
            text_embedding = None
        else:
            tokenized_texts = self.tokenizer(texts, return_tensors="pt")
            text_embedding = self.text_model(**tokenized_texts)
            text_embedding = text_embedding.last_hidden_state[:, 0, :]
            text_embedding = self.text_interaction_layer(text_embedding)

        out = self.timer(input_ids=input_ids, vision_embedding=vision_embedding, text_embedding=text_embedding)
        out = out["logits"]

        if labels is not None:
            if self.config.forecasting_length == out.shape[-1]:
                loss = torch.mean(torch.square(out - labels))  # MSE
            else:
                # The pretrained Timer forecasts 96 steps; this handles shorter forecasting
                # lengths. A forecasting length larger than 96 will cause an error.
                loss = torch.mean(torch.square(out[:, :self.config.forecasting_length] - labels))
        else:
            loss = None

        return {
            "loss": loss,
            "logits": out
        }

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        from transformers.utils import cached_file

        config = MulTiCastTimerConfig.from_pretrained(pretrained_model_name_or_path)
        model = MulTiCastTimerModel(config)
        resolved_file = cached_file(pretrained_model_name_or_path, "model.safetensors")
        state_dict = load_file(resolved_file)
        model.load_state_dict(state_dict, strict=False)
        return model
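
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module).
# The context length, image size, and example text below are assumptions;
# adapt them to your own data. Running this downloads the CLIP, Qwen2, and
# Timer checkpoints referenced in __init__, and because this file uses
# relative imports it must be run as a module inside its package
# (python -m <package>.<this_module>).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image

    config = MulTiCastTimerConfig(
        forecasting_length=96,
        vision_model_name="CLIP",
        text_model_name="Qwen",
    )
    model = MulTiCastTimerModel(config)
    model.eval()

    # One univariate context window (batch size 1; the length 672 is an assumed
    # multiple of the Timer input token length).
    input_ids = torch.randn(1, 672)
    # A placeholder image and a short description of the series.
    image = Image.new("RGB", (224, 224))
    texts = ["Hourly electricity consumption of a residential building."]

    with torch.no_grad():
        out = model.predict(input_ids=input_ids, images=image, texts=texts)
    print(out["logits"].shape)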