tomer-deci committed on
Commit
49422ae
1 Parent(s): 64cf789

added support for text-generation pipeline

Browse files
Files changed (3) hide show
  1. config.json +2 -1
  2. modeling_decilm.py +9 -10
  3. version_check.py +3 -3
config.json CHANGED
@@ -25,5 +25,6 @@
25
  "use_bfloat16": true,
26
  "transformers_version": "4.35.2",
27
  "use_cache": true,
28
- "vocab_size": 32000
 
29
  }
 
25
  "use_bfloat16": true,
26
  "transformers_version": "4.35.2",
27
  "use_cache": true,
28
+ "vocab_size": 32000,
29
+ "tokenizer_class": "LlamaTokenizer"
30
  }
modeling_decilm.py CHANGED
@@ -1,5 +1,5 @@
1
  # coding=utf-8
2
- # Copyright and license here
3
  """ PyTorch DeciLM model."""
4
  from .version_check import check_transformers_version
5
 
@@ -11,18 +11,18 @@ import torch
11
  import torch.nn.functional as F
12
  import torch.utils.checkpoint
13
  from torch import nn
 
 
14
 
 
 
15
  from .transformers_v4_35_2__modeling_llama import LlamaMLP, LlamaRMSNorm, LlamaAttention, apply_rotary_pos_emb, \
16
  repeat_kv, LlamaPreTrainedModel, LLAMA_START_DOCSTRING, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, \
17
  BaseModelOutputWithPast, LLAMA_INPUTS_DOCSTRING
18
- from .transformers_v4_35_2__modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
19
- from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
20
-
21
- from .configuration_decilm import DeciLMConfig
22
-
23
- logger = logging.get_logger(__name__)
24
 
 
25
  _CONFIG_FOR_DOC = "DeciLMConfig"
 
26
 
27
 
28
  class DeciLMAttention(LlamaAttention):
@@ -61,7 +61,7 @@ class DeciLMAttention(LlamaAttention):
61
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
62
  output_attentions: bool = False,
63
  use_cache: bool = False,
64
- # padding_mask: Optional[torch.LongTensor] = None,
65
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
66
  bsz, q_len, _ = hidden_states.size()
67
  is_decode = past_key_value is not None
@@ -135,8 +135,7 @@ class DeciLMAttention(LlamaAttention):
135
  else:
136
  attn_output = self.o_proj(attn_output)
137
 
138
- if not output_attentions:
139
- attn_weights = None
140
 
141
  return attn_output, attn_weights, past_key_value
142
 
 
1
  # coding=utf-8
2
+ # Copyright and license in the repo.
3
  """ PyTorch DeciLM model."""
4
  from .version_check import check_transformers_version
5
 
 
11
  import torch.nn.functional as F
12
  import torch.utils.checkpoint
13
  from torch import nn
14
+ from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
15
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
16
 
17
+ from .configuration_decilm import DeciLMConfig
18
+ from .transformers_v4_35_2__modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
19
  from .transformers_v4_35_2__modeling_llama import LlamaMLP, LlamaRMSNorm, LlamaAttention, apply_rotary_pos_emb, \
20
  repeat_kv, LlamaPreTrainedModel, LLAMA_START_DOCSTRING, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, \
21
  BaseModelOutputWithPast, LLAMA_INPUTS_DOCSTRING
 
 
 
 
 
 
22
 
23
+ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES["deci"] = "DeciLMForCausalLM"
24
  _CONFIG_FOR_DOC = "DeciLMConfig"
25
+ logger = logging.get_logger(__name__)
26
 
27
 
28
  class DeciLMAttention(LlamaAttention):
 
61
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
62
  output_attentions: bool = False,
63
  use_cache: bool = False,
64
+ **kwargs,
65
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
66
  bsz, q_len, _ = hidden_states.size()
67
  is_decode = past_key_value is not None
 
135
  else:
136
  attn_output = self.o_proj(attn_output)
137
 
138
+ attn_weights = None
 
139
 
140
  return attn_output, attn_weights, past_key_value
141
 
version_check.py CHANGED
@@ -1,11 +1,11 @@
1
  import transformers
2
  from packaging import version
3
 
4
- VERSION = "4.35.2"
5
 
6
 
7
  def check_transformers_version():
8
- if version.parse(transformers.__version__) < version.parse(VERSION):
9
  raise ImportError(
10
- f"You are using transformers=={transformers.__version__}, but transformers>={VERSION} is required to use DeciLM. Please upgrade transformers."
11
  )
 
1
  import transformers
2
  from packaging import version
3
 
4
+ MIN_VERSION = "4.35.2"
5
 
6
 
7
  def check_transformers_version():
8
+ if version.parse(transformers.__version__) < version.parse(MIN_VERSION):
9
  raise ImportError(
10
+ f"You are using transformers=={transformers.__version__}, but transformers>={MIN_VERSION} is required to use DeciLM. Please upgrade transformers."
11
  )