Q-bert committed on
Commit
f3bd2f6
1 Parent(s): 50371ba

Upload 2 files

Files changed (2)
  1. configuration_stockllama.py +65 -0
  2. modeling_stockllama.py +140 -0
configuration_stockllama.py ADDED
@@ -0,0 +1,65 @@
+ from transformers.configuration_utils import PretrainedConfig
+
+ class StockLlamaConfig(PretrainedConfig):
+     model_type = "stockllama"
+     keys_to_ignore_at_inference = ["past_key_values"]
+
+     def __init__(
+         self,
+         vocab_size=32000,
+         hidden_size=4096,
+         intermediate_size=11008,
+         num_hidden_layers=32,
+         num_attention_heads=32,
+         num_key_value_heads=None,
+         hidden_act="silu",
+         max_position_embeddings=2048,
+         term_number=4,
+         initializer_range=0.02,
+         rms_norm_eps=1e-6,
+         use_cache=True,
+         pad_token_id=None,
+         bos_token_id=1,
+         eos_token_id=2,
+         pretraining_tp=1,
+         tie_word_embeddings=False,
+         rope_theta=10000.0,
+         rope_scaling=None,
+         attention_bias=False,
+         attention_dropout=0.0,
+         mlp_bias=False,
+         head_dim=None,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
+         self.term_number = term_number
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+
+         if num_key_value_heads is None:
+             num_key_value_heads = num_attention_heads
+
+         self.num_key_value_heads = num_key_value_heads
+         self.hidden_act = hidden_act
+         self.initializer_range = initializer_range
+         self.rms_norm_eps = rms_norm_eps
+         self.pretraining_tp = pretraining_tp
+         self.use_cache = use_cache
+         self.rope_theta = rope_theta
+         self.rope_scaling = rope_scaling
+         self.attention_bias = attention_bias
+         self.attention_dropout = attention_dropout
+         self.mlp_bias = mlp_bias
+         self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
+
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
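For reference, a minimal usage sketch (not part of the uploaded files); it assumes configuration_stockllama.py is importable from the working directory and only exercises the defaults shown above. `term_number` is the StockLlama-specific field: it sets how many decimal digits the float embedding in modeling_stockllama.py resolves, so the fractional lookup table has 10**term_number rows.

from configuration_stockllama import StockLlamaConfig

# Llama-style defaults plus the StockLlama-specific `term_number`.
config = StockLlamaConfig(
    vocab_size=32000,   # rows in the integer-part embedding table
    hidden_size=4096,
    term_number=4,      # fractional part quantized to 4 decimal digits -> 10**4 bins
)
print(config.model_type)  # "stockllama"
print(config.head_dim)    # 4096 // 32 = 128 when head_dim is left unset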
modeling_stockllama.py ADDED
@@ -0,0 +1,140 @@
+ from torch import nn
+ from torch.nn import functional as F
+ import torch
+
+ from configuration_stockllama import StockLlamaConfig
+
+ from transformers.models.llama.modeling_llama import LlamaPreTrainedModel
+ from transformers.models.llama.modeling_llama import LlamaModel
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
+ from transformers.cache_utils import Cache
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+ import math
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ class FloatEmbedding(nn.Module):
+     def __init__(self, vocab_size, hidden_size, padding_idx, term_number):
+         super(FloatEmbedding, self).__init__()
+         self.term_number = term_number
+         self.int_part = nn.Embedding(vocab_size, hidden_size, padding_idx)
+         self.float_part = nn.Embedding(10**term_number, hidden_size)
+
+     def forward(self, input):
+         float_input = ((input - torch.floor(input)) * (10**self.term_number)).to(torch.long)
+         int_input = input.to(torch.long)
+         output = self.float_part(float_input) + self.int_part(int_input)
+
+         return output
+
+ class StockLlamaPreTrainedModel(LlamaPreTrainedModel):
+     config_class = StockLlamaConfig
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["LlamaDecoderLayer"]
+     _skip_keys_device_placement = ["past_key_values"]
+     _supports_flash_attn_2 = True
+     _supports_sdpa = True
+     _supports_cache_class = True
+     _supports_quantized_cache = True
+     _supports_static_cache = True
+
+     def _init_weights(self, module):
+         std = self.config.initializer_range
+         if isinstance(module, nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+
+ class StockLlamaModel(LlamaModel):
+     config_class = StockLlamaConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self._use_flash_attention_2 = True
+         self.embed_tokens = FloatEmbedding(config.vocab_size, config.hidden_size, self.padding_idx, config.term_number)
+         self.post_init()
+
+
+ class StockLlamaForForecasting(StockLlamaPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = StockLlamaModel(config)
+         self.score = nn.Linear(config.hidden_size, 1, bias=False)
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.model.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.model.embed_tokens = value
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         transformer_outputs = self.model(
+             input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         hidden_states = transformer_outputs[0]
+         logits = self.score(hidden_states)
+
+         if input_ids is not None:
+             batch_size = input_ids.shape[0]
+         else:
+             batch_size = inputs_embeds.shape[0]
+
+         if self.config.pad_token_id is None and batch_size != 1:
+             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+         if self.config.pad_token_id is None:
+             sequence_lengths = -1
+         else:
+             if input_ids is not None:
+                 sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                 sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                 sequence_lengths = sequence_lengths.to(logits.device)
+             else:
+                 sequence_lengths = -1
+
+         pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+         loss = None
+         if labels is not None:
+             labels = labels.to(logits.device)
+             loss_fct = MSELoss()
+             loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+
+         if not return_dict:
+             output = (pooled_logits,) + transformer_outputs[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutputWithPast(
+             loss=loss,
+             logits=pooled_logits,
+             past_key_values=transformer_outputs.past_key_values,
+             hidden_states=transformer_outputs.hidden_states,
+             attentions=transformer_outputs.attentions,
+         )
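A small end-to-end sketch, not part of the commit, of how the pieces above fit together: FloatEmbedding splits each float input into an integer part (looked up in a vocab_size-row table) and a fractional part quantized to term_number decimal digits (looked up in a 10**term_number-row table), and StockLlamaForForecasting pools the last (or last non-pad) position and regresses a single value against `labels` with MSE. The tiny hyperparameters below are placeholders chosen only to keep the example light, and the sketch assumes the installed transformers version passes a float tensor given as input_ids straight through embed_tokens.

import torch

from configuration_stockllama import StockLlamaConfig
from modeling_stockllama import StockLlamaForForecasting

# Deliberately tiny model, just to exercise the code path.
config = StockLlamaConfig(
    vocab_size=512,           # integer parts must stay below this value
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    max_position_embeddings=64,
    term_number=2,            # prices resolved to 2 decimal digits
    pad_token_id=0,           # enables the last-non-pad pooling branch
)
model = StockLlamaForForecasting(config)

prices = torch.tensor([[101.25, 101.30, 101.10, 101.40]])  # float "token ids"
target = torch.tensor([101.55])                            # value to regress toward

out = model(input_ids=prices, labels=target)
print(out.loss)           # scalar MSE loss
print(out.logits.shape)   # torch.Size([1, 1]): one pooled prediction per sequence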