mansaripo commited on
Commit
f849730
·
verified ·
1 Parent(s): 9d9ad39

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. config.json +1 -4
  2. modeling_cloverlm.py +3 -15
  3. tokenizer_config.json +1 -4
config.json CHANGED
@@ -6,10 +6,7 @@
6
  "auto_map": {
7
  "AutoConfig": "configuration_cloverlm.CloverLMConfig",
8
  "AutoModelForCausalLM": "modeling_cloverlm.CloverLMForCausalLM",
9
- "AutoTokenizer": [
10
- "tokenization_cloverlm.CloverLMTokenizer",
11
- null
12
- ]
13
  },
14
  "d_head": 128,
15
  "heads": 28,
 
6
  "auto_map": {
7
  "AutoConfig": "configuration_cloverlm.CloverLMConfig",
8
  "AutoModelForCausalLM": "modeling_cloverlm.CloverLMForCausalLM",
9
+ "AutoTokenizer": "tokenization_cloverlm.CloverLMTokenizer"
 
 
 
10
  },
11
  "d_head": 128,
12
  "heads": 28,
modeling_cloverlm.py CHANGED
@@ -111,27 +111,15 @@ class MHSA(nn.Module):
111
 
112
  dtype = Q.dtype if Q.dtype in (torch.bfloat16, torch.float16) else torch.bfloat16
113
  if attn_backend == "flash2":
114
- try:
115
- import flash_attn
116
- except ImportError as e:
117
- e.add_note(f"Can't run `attn_backend=flash2` because can't import flash_attn")
118
- raise e
119
  Y = flash_attn.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=True, softmax_scale=1.0)
120
  elif attn_backend == "flash3":
121
  import importlib
122
- try:
123
- _fa3 = importlib.import_module("flash_attn_interface")
124
- except ImportError as e:
125
- e.add_note(f"Can't run `attn_backend=flash3` because can't import flash_attn_interface")
126
- raise e
127
  Y = _fa3.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=True, softmax_scale=1.0)
128
  elif attn_backend == "flash4":
129
  import importlib
130
- try:
131
- _fa4 = importlib.import_module("flash_attn.cute")
132
- except ImportError as e:
133
- e.add_note(f"Can't run `attn_backend=flash4` because can't import flash_attn.cute")
134
- raise e
135
  Y = _fa4.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=True, softmax_scale=1.0)[0]
136
  Y = Y.to(Q.dtype).flatten(-2, -1)
137
 
 
111
 
112
  dtype = Q.dtype if Q.dtype in (torch.bfloat16, torch.float16) else torch.bfloat16
113
  if attn_backend == "flash2":
114
+ import flash_attn
 
 
 
 
115
  Y = flash_attn.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=True, softmax_scale=1.0)
116
  elif attn_backend == "flash3":
117
  import importlib
118
+ _fa3 = importlib.import_module("flash_attn_interface")
 
 
 
 
119
  Y = _fa3.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=True, softmax_scale=1.0)
120
  elif attn_backend == "flash4":
121
  import importlib
122
+ _fa4 = importlib.import_module("flash_attn.cute")
 
 
 
 
123
  Y = _fa4.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=True, softmax_scale=1.0)[0]
124
  Y = Y.to(Q.dtype).flatten(-2, -1)
125
 
tokenizer_config.json CHANGED
@@ -1,10 +1,7 @@
1
  {
2
  "tokenizer_class": "CloverLMTokenizer",
3
  "auto_map": {
4
- "AutoTokenizer": [
5
- "tokenization_cloverlm.CloverLMTokenizer",
6
- null
7
- ]
8
  },
9
  "use_fast": false
10
  }
 
1
  {
2
  "tokenizer_class": "CloverLMTokenizer",
3
  "auto_map": {
4
+ "AutoTokenizer": "tokenization_cloverlm.CloverLMTokenizer"
 
 
 
5
  },
6
  "use_fast": false
7
  }