# MulTiCast / modeling_multicasttimer.py
import torch
from torch import nn
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    CLIPImageProcessor,
    PreTrainedModel,
    PretrainedConfig,
)
from safetensors.torch import load_file

# CLIP
from .modeling_clipPT import CLIPVisionTransformer
# Qwen
from .modeling_qwen2 import Qwen2Model
# Timer
from .modeling_timer import TimerForPrediction

class MulTiCastTimerConfig(PretrainedConfig):
    def __init__(
        self,
        forecasting_length=None,
        vision_model_name=None,
        text_model_name=None,
        vision_model_prompt_len=None,
        text_model_prompt_len=None,
        timer_prompt_len=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.forecasting_length = forecasting_length
        self.vision_model_name = vision_model_name
        self.text_model_name = text_model_name
        # Prompt lengths fall back to defaults when not given explicitly.
        self.vision_model_prompt_len = vision_model_prompt_len if vision_model_prompt_len is not None else 10
        self.text_model_prompt_len = text_model_prompt_len if text_model_prompt_len is not None else 4
        self.timer_prompt_len = timer_prompt_len if timer_prompt_len is not None else 4
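
# Example (hypothetical values): a config that enables both the CLIP vision
# branch and the Qwen text branch, forecasting 96 steps ahead:
#
#   config = MulTiCastTimerConfig(
#       forecasting_length=96,
#       vision_model_name='CLIP',
#       text_model_name='Qwen',
#   )
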

class MulTiCastTimerModel(PreTrainedModel):
    config_class = MulTiCastTimerConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Vision model (other names are silently ignored, as in the original branches)
        if config.vision_model_name == 'CLIP':
            vision_model = AutoModel.from_pretrained("openai/clip-vit-base-patch32").vision_model
            state_dict = vision_model.state_dict()
            state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
            # strict=False: the prompt parameters added by CLIPVisionTransformer
            # are not present in the pretrained checkpoint.
            self.vision_model = CLIPVisionTransformer(vision_model.config, config.vision_model_prompt_len)
            self.vision_model.load_state_dict(state_dict, strict=False)
            self.processor = CLIPImageProcessor()  # defaults match the openai CLIP preprocessing
            # Freeze everything except the learnable prompts
            for name, param in self.vision_model.named_parameters():
                param.requires_grad = "encoder.prompts" in name
        # Text model
        if config.text_model_name == 'Qwen':
            self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
            text_model = AutoModelForCausalLM.from_pretrained(
                "Qwen/Qwen2-1.5B-Instruct",
                torch_dtype=torch.bfloat16,
                device_map="cpu",
                attn_implementation="sdpa"
            ).model
            state_dict = text_model.state_dict()
            # strict=False: the prompt parameters added by Qwen2Model are not
            # present in the pretrained checkpoint.
            self.text_model = Qwen2Model(text_model.config, config.text_model_prompt_len)
            self.text_model.load_state_dict(state_dict, strict=False)
            # Freeze everything except the learnable prompts
            for name, param in self.text_model.named_parameters():
                param.requires_grad = "prompts" in name
        # Timer backbone
        timer = AutoModelForCausalLM.from_pretrained('thuml/timer-base-84m', trust_remote_code=True)
        state_dict = timer.state_dict()
        # strict=False: the prompt parameters added by TimerForPrediction are
        # not present in the pretrained checkpoint.
        self.timer = TimerForPrediction(timer.config, config.timer_prompt_len)
        self.timer.load_state_dict(state_dict, strict=False)
        # Freeze everything except the learnable prompts
        for name, param in self.timer.named_parameters():
            param.requires_grad = "model.prompts" in name
        # Interaction layers project the frozen encoders' embeddings into the
        # Timer hidden space.
        if config.vision_model_name is not None:
            self.vision_interaction_layer = nn.Linear(self.vision_model.config.hidden_size, self.timer.config.hidden_size)
        if config.text_model_name is not None:
            self.text_interaction_layer = nn.Linear(self.text_model.config.hidden_size, self.timer.config.hidden_size)
    def predict(self, input_ids=None, images=None, texts=None):
        # Skip the vision branch unless it is both configured and given input;
        # the original `and` raised AttributeError when only one was set, and
        # preprocessing ran before the None check.
        if self.config.vision_model_name is None or images is None:
            vision_embedding = None
            vision_attentions = None
        else:
            # Preprocess a single raw image into a (1, C, H, W) pixel tensor
            pixel_values = self.processor.preprocess(images)['pixel_values'][0]
            pixel_values = torch.tensor(pixel_values).unsqueeze(0)
            vision_output = self.vision_model(pixel_values, output_attentions=True)
            vision_attentions = vision_output.attentions
            vision_embedding = vision_output.pooler_output
            vision_embedding = self.vision_interaction_layer(vision_embedding)
        if self.config.text_model_name is None or texts is None or all(x is None for x in texts):
            text_embedding = None
        else:
            # padding=True lets batched texts of different lengths tokenize together
            tokenized_texts = self.tokenizer(texts, return_tensors="pt", padding=True)
            text_embedding = self.text_model(**tokenized_texts)
            # First-token hidden state is used as the sentence embedding
            text_embedding = text_embedding.last_hidden_state[:, 0, :]
            text_embedding = self.text_interaction_layer(text_embedding)
        out = self.timer(input_ids=input_ids, vision_embedding=vision_embedding, text_embedding=text_embedding)
        return {
            "logits": out.logits,
            "vision_attentions": vision_attentions,
            "time_series_attentions": out.attentions
        }
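
    # Example call (hypothetical inputs): `series` is a (batch, context_length)
    # float tensor of past values, `pil_image` a PIL image, `texts` a list of
    # strings; the method returns forecasts plus attention maps for inspection:
    #
    #   out = model.predict(input_ids=series, images=pil_image, texts=["station 7"])
    #   forecast, attn = out["logits"], out["vision_attentions"]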
    def forward(self, input_ids=None, images=None, texts=None, labels=None):
        # Unlike predict, `images` is expected to already be a preprocessed
        # pixel tensor here.
        if self.config.vision_model_name is None or images is None:
            vision_embedding = None
        else:
            vision_embedding = self.vision_model(images)
            vision_embedding = vision_embedding.pooler_output
            vision_embedding = self.vision_interaction_layer(vision_embedding)
        if self.config.text_model_name is None or texts is None or all(x is None for x in texts):
            text_embedding = None
        else:
            tokenized_texts = self.tokenizer(texts, return_tensors="pt", padding=True)
            text_embedding = self.text_model(**tokenized_texts)
            text_embedding = text_embedding.last_hidden_state[:, 0, :]
            text_embedding = self.text_interaction_layer(text_embedding)
        out = self.timer(input_ids=input_ids, vision_embedding=vision_embedding, text_embedding=text_embedding)
        out = out["logits"]
        if labels is not None:
            if self.config.forecasting_length == out.shape[-1]:
                loss = torch.mean(torch.square(out - labels))  # MSE
            else:
                # The pretrained Timer predicts 96 steps; slice when the target
                # forecasting length is shorter. A forecasting length larger
                # than 96 will raise an error here.
                loss = torch.mean(torch.square(out[:, :self.config.forecasting_length] - labels))
        else:
            loss = None
        return {
            "loss": loss,
            "logits": out
        }
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        # Overrides PreTrainedModel.from_pretrained: __init__ rebuilds the
        # frozen backbones from their original checkpoints, then the fine-tuned
        # weights are loaded on top (strict=False tolerates keys the saved
        # file does not cover).
        from transformers.utils import cached_file
        config = MulTiCastTimerConfig.from_pretrained(pretrained_model_name_or_path)
        model = MulTiCastTimerModel(config)
        resolved_file = cached_file(pretrained_model_name_or_path, "model.safetensors")
        state_dict = load_file(resolved_file)
        model.load_state_dict(state_dict, strict=False)
        return model
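

# Minimal usage sketch, kept as a comment because this module is imported as a
# package (relative imports). "path/to/checkpoint" and the tensor shapes are
# placeholder assumptions, not part of this module:
#
#   model = MulTiCastTimerModel.from_pretrained("path/to/checkpoint")
#   series = torch.randn(1, 672)  # hypothetical (batch, context_length) series
#   out = model(input_ids=series, images=None, texts=None)
#   print(out["logits"].shape)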