import torch
from torch import nn
from safetensors.torch import load_file
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    CLIPImageProcessor,
    PretrainedConfig,
    PreTrainedModel,
)
from transformers.utils import cached_file

from .modeling_clipPT import CLIPVisionTransformer
from .modeling_qwen2 import Qwen2Model
from .modeling_timer import TimerForPrediction


class MulTiCastTimerConfig(PretrainedConfig):
    def __init__(
        self,
        forecasting_length=None,
        vision_model_name=None,
        text_model_name=None,
        vision_model_prompt_len=None,
        text_model_prompt_len=None,
        timer_prompt_len=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.forecasting_length = forecasting_length
        self.vision_model_name = vision_model_name
        self.text_model_name = text_model_name
        # Fall back to the default prompt lengths when they are not specified.
        self.vision_model_prompt_len = vision_model_prompt_len if vision_model_prompt_len is not None else 10
        self.text_model_prompt_len = text_model_prompt_len if text_model_prompt_len is not None else 4
        self.timer_prompt_len = timer_prompt_len if timer_prompt_len is not None else 4
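
# Illustrative config construction (the values below are placeholders, not the
# released checkpoint's settings); unset prompt lengths default to 10/4/4:
#
#     cfg = MulTiCastTimerConfig(
#         forecasting_length=96,
#         vision_model_name="CLIP",
#         text_model_name="Qwen",
#     )
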
class MulTiCastTimerModel(PreTrainedModel):
    config_class = MulTiCastTimerConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Vision tower: prompt-tuned CLIP vision transformer initialised from the
        # pretrained checkpoint, with only the prompt tokens left trainable.
        if config.vision_model_name == 'CLIP':
            vision_model = AutoModel.from_pretrained("openai/clip-vit-base-patch32").vision_model
            state_dict = {k: v.to(torch.bfloat16) for k, v in vision_model.state_dict().items()}
            self.vision_model = CLIPVisionTransformer(vision_model.config, config.vision_model_prompt_len)
            self.vision_model.load_state_dict(state_dict, strict=False)
            self.processor = CLIPImageProcessor()
            for name, param in self.vision_model.named_parameters():
                param.requires_grad = "encoder.prompts" in name

        # Text tower: prompt-tuned Qwen2 backbone initialised from the pretrained
        # checkpoint, with only the prompt tokens left trainable.
        if config.text_model_name == 'Qwen':
            self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
            text_model = AutoModelForCausalLM.from_pretrained(
                "Qwen/Qwen2-1.5B-Instruct",
                torch_dtype=torch.bfloat16,
                device_map="cpu",
                attn_implementation="sdpa",
            ).model
            self.text_model = Qwen2Model(text_model.config, config.text_model_prompt_len)
            self.text_model.load_state_dict(text_model.state_dict(), strict=False)
            for name, param in self.text_model.named_parameters():
                param.requires_grad = "prompts" in name

        # Time-series backbone: prompt-tuned Timer, always loaded; only the prompt
        # tokens remain trainable.
        timer = AutoModelForCausalLM.from_pretrained('thuml/timer-base-84m', trust_remote_code=True)
        self.timer = TimerForPrediction(timer.config, config.timer_prompt_len)
        self.timer.load_state_dict(timer.state_dict(), strict=False)
        for name, param in self.timer.named_parameters():
            param.requires_grad = "model.prompts" in name

        # Linear projections that map the auxiliary embeddings into Timer's hidden space.
        if config.vision_model_name is not None:
            self.vision_interaction_layer = nn.Linear(self.vision_model.config.hidden_size, self.timer.config.hidden_size)
        if config.text_model_name is not None:
            self.text_interaction_layer = nn.Linear(self.text_model.config.hidden_size, self.timer.config.hidden_size)

    def predict(self, input_ids=None, images=None, texts=None):
        # Inference entry point: accepts a raw image (preprocessed here) and raw text,
        # and additionally returns attention maps for inspection.
        vision_embedding = None
        vision_attentions = None
        if self.config.vision_model_name is not None and images is not None:
            pixel_values = self.processor.preprocess(images)['pixel_values'][0]
            pixel_values = torch.tensor(pixel_values).unsqueeze(0)
            vision_output = self.vision_model(pixel_values, output_attentions=True)
            vision_attentions = vision_output.attentions
            vision_embedding = self.vision_interaction_layer(vision_output.pooler_output)

        text_embedding = None
        if self.config.text_model_name is not None and texts is not None and not all(x is None for x in texts):
            tokenized_texts = self.tokenizer(texts, return_tensors="pt")
            text_output = self.text_model(**tokenized_texts)
            # Use the first token's hidden state as the sequence-level text embedding.
            text_embedding = self.text_interaction_layer(text_output.last_hidden_state[:, 0, :])

        out = self.timer(input_ids=input_ids, vision_embedding=vision_embedding, text_embedding=text_embedding)

        return {
            "logits": out.logits,
            "vision_attentions": vision_attentions,
            "time_series_attentions": out.attentions,
        }

    def forward(self, input_ids=None, images=None, texts=None, labels=None):
        # Training/evaluation path: `images` is expected to already be preprocessed
        # pixel values (see `predict` for raw-image handling).
        vision_embedding = None
        if self.config.vision_model_name is not None and images is not None:
            vision_output = self.vision_model(images)
            vision_embedding = self.vision_interaction_layer(vision_output.pooler_output)

        text_embedding = None
        if self.config.text_model_name is not None and texts is not None and not all(x is None for x in texts):
            tokenized_texts = self.tokenizer(texts, return_tensors="pt")
            text_output = self.text_model(**tokenized_texts)
            text_embedding = self.text_interaction_layer(text_output.last_hidden_state[:, 0, :])

        out = self.timer(input_ids=input_ids, vision_embedding=vision_embedding, text_embedding=text_embedding)
        logits = out["logits"]

        loss = None
        if labels is not None:
            if self.config.forecasting_length == logits.shape[-1]:
                # Mean squared error over the full predicted horizon.
                loss = torch.mean(torch.square(logits - labels))
            else:
                # The backbone may emit more steps than needed; score only the configured horizon.
                loss = torch.mean(torch.square(logits[:, :self.config.forecasting_length] - labels))

        return {
            "loss": loss,
            "logits": logits,
        }

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        # Rebuild the composite model from its config, then overlay the fine-tuned
        # weights (prompt tokens and interaction layers) stored in model.safetensors.
        config = MulTiCastTimerConfig.from_pretrained(pretrained_model_name_or_path)
        model = cls(config)
        resolved_file = cached_file(pretrained_model_name_or_path, "model.safetensors")
        state_dict = load_file(resolved_file)
        model.load_state_dict(state_dict, strict=False)
        return model
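

# Minimal usage sketch (illustrative only; the checkpoint path and tensor shapes
# below are placeholders). The model is normally loaded through the transformers
# remote-code mechanism, so this module is assumed to be imported as part of its
# package rather than run directly.
#
#     model = MulTiCastTimerModel.from_pretrained("path/to/checkpoint")  # placeholder path
#     series = torch.randn(1, 672)  # (batch, context_length) placeholder time series
#     out = model(input_ids=series, images=None, texts=[None])
#     print(out["logits"].shape)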