Siddharth63 committed on
Commit
e8028e4
1 Parent(s): ca7af1c

Update README.md

Files changed (1): README.md (+68 -3)

---
license: artistic-2.0
datasets:
- Siddharth63/biological_dataset
- Siddharth63/clinical_dataset
---

BitNet 250M trained on 7B tokens of PubMed and clinical text.

Inference code:

```python
import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer
# Explicit imports in place of the original wildcard import
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaMLP,
    LlamaRMSNorm,
    LlamaSdpaAttention,
)

# Load a pretrained BitNet model and its tokenizer
model_id = "Siddharth63/Bitnet-250M"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)


def activation_quant(x):
    # Per-token absmax quantization of activations to 8 bits
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127)
    y = y / scale  # dequantize back to the original scale
    return y


def weight_quant(w):
    # Absmean quantization of weights to the ternary grid {-1, 0, +1}
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    u = (w * scale).round().clamp_(-1, 1)
    u = u / scale  # rescale to preserve the original magnitude
    return u
```
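As a quick illustration of what the two quantizers do (the tensors below are made up for this example, not from the model card): activations snap to a 256-level per-token grid, while weights collapse to scaled ternary values.

```python
# Illustrative only: run the quantizers on small made-up tensors.
x = torch.tensor([[0.10, -0.42, 0.75, -0.03]])
print(activation_quant(x))  # rounded to the 8-bit grid, then rescaled

w = torch.tensor([[0.30, -0.05, 1.20], [0.00, -0.70, 0.20]])
print(weight_quant(w))  # each entry becomes -m, 0, or +m, where m = w.abs().mean()
```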
`BitLinear` replaces `nn.Linear`: it RMS-normalizes its input, then quantizes activations and weights on the fly.

```python
class BitLinear(nn.Linear):
    def forward(self, x):
        w = self.weight  # a weight tensor with shape [d, k]
        x = x.to(w.device)
        rms_norm = LlamaRMSNorm(x.shape[-1]).to(w.device)
        x_norm = rms_norm(x)
        # A trick for implementing the Straight-Through Estimator (STE) using detach():
        # the forward pass uses quantized values, the backward pass sees the identity
        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        y = F.linear(x_quant, w_quant)
        return y
```
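As a minimal sanity check (the layer sizes here are arbitrary), a freshly constructed `BitLinear` can be called like any `nn.Linear`; the quantization happens inside its forward pass:

```python
# Illustrative only: BitLinear is a drop-in replacement for nn.Linear.
layer = BitLinear(8, 4, bias=False)
out = layer(torch.randn(2, 8))
print(out.shape)  # torch.Size([2, 4])
```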
`convert_to_bitnet` swaps every `nn.Linear` inside the attention and MLP blocks for a `BitLinear`, and replaces each decoder layer's `input_layernorm` with an identity, since `BitLinear` applies its own RMSNorm.

```python
def convert_to_bitnet(model, copy_weights):
    for name, module in model.named_modules():
        # Replace linear layers with BitNet
        if isinstance(module, (LlamaSdpaAttention, LlamaMLP)):
            for child_name, child_module in module.named_children():
                if isinstance(child_module, nn.Linear):
                    bitlinear = BitLinear(
                        child_module.in_features,
                        child_module.out_features,
                        child_module.bias is not None,
                    ).to(device="cuda:0")
                    if copy_weights:
                        bitlinear.weight = child_module.weight
                        if child_module.bias is not None:
                            bitlinear.bias = child_module.bias
                    setattr(module, child_name, bitlinear)
        # Remove redundant input_layernorms
        elif isinstance(module, LlamaDecoderLayer):
            for child_name, child_module in module.named_children():
                if isinstance(child_module, LlamaRMSNorm) and child_name == "input_layernorm":
                    setattr(module, child_name, nn.Identity().to(device="cuda:0"))


convert_to_bitnet(model, copy_weights=True)
model.to(device="cuda:0")
```
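A quick way to confirm the swap took effect (this counting snippet is ours, purely illustrative):

```python
# Illustrative only: count how many BitLinear layers the conversion produced.
num_bitlinear = sum(isinstance(m, BitLinear) for _, m in model.named_modules())
print(f"BitLinear layers: {num_bitlinear}")
```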
Finally, run generation on a prompt:

```python
prompt = "Atherosclerosis is"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
generate_ids = model.generate(inputs.input_ids, max_length=50)
print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])
```
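The call above decodes greedily up to 50 tokens; for longer or more varied output, the usual `generate` sampling options apply (the settings below are arbitrary examples, not recommendations from the model card):

```python
# Illustrative only: sampled generation with arbitrary settings.
generate_ids = model.generate(
    inputs.input_ids,
    max_length=100,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
)
print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0])
```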