---
license: artistic-2.0
datasets:
- Siddharth63/biological_dataset
- Siddharth63/clinical_dataset
---

BitNet 250M, trained on 7B tokens from the PubMed (biological) and clinical datasets.

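Inference code. The checkpoint loads as a standard Llama causal LM. `convert_to_bitnet` then swaps its attention/MLP linear layers for `BitLinear` layers before generation:
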
```python
import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import *
# Explicit imports: don't rely on the star import re-exporting torch/nn/F or
# the Llama building blocks used below.
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaMLP,
    LlamaRMSNorm,
)

# Load the pretrained checkpoint as a regular Llama causal LM;
# convert_to_bitnet() below swaps in the BitLinear layers.
model_id = "Siddharth63/Bitnet-250M"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)


def activation_quant(x):
    """Fake-quantize activations to 8-bit levels using absmax scaling.

    Each slice along the last dimension is scaled so that its largest
    magnitude maps to 127, rounded, clamped to [-128, 127], and scaled back.
    The result has the same shape as ``x`` but holds only quantized values.
    """
    # Floor the absmax so an all-zero slice does not divide by zero.
    absmax = x.abs().max(dim=-1, keepdim=True).values
    factor = 127.0 / absmax.clamp_(min=1e-5)
    levels = (x * factor).round().clamp_(-128, 127)
    return levels / factor


def weight_quant(w):
    """Fake-quantize a weight tensor to ternary levels using absmean scaling.

    The whole tensor shares one scale, the mean absolute value of ``w``.
    Weights are divided by that mean, rounded, clamped to [-1, 1], and then
    multiplied back. Every entry of the result is therefore -mean, 0 or
    +mean, up to floating-point rounding.
    """
    # Floor the mean so an all-zero tensor does not divide by zero.
    gamma = w.abs().mean().clamp_(min=1e-5)
    inv_gamma = 1.0 / gamma
    ternary = (w * inv_gamma).round().clamp_(-1, 1)
    return ternary / inv_gamma


class BitLinear(nn.Linear):
    """Drop-in ``nn.Linear`` replacement with BitNet-style fake quantization.

    On each forward pass the input is RMS-normalized without a learned scale.
    The normalized activations are fake-quantized to 8 bits and the weights
    to ternary values. A straight-through estimator (STE) keeps both paths
    differentiable. The constructor is inherited unchanged from
    ``nn.Linear``, so the parameters (``weight`` and optional ``bias``) and
    the state_dict keys match a plain linear layer.
    """

    # Epsilon of the input RMS normalization (same as LlamaRMSNorm's default).
    norm_eps = 1e-6

    def _rms_norm(self, x):
        """Scale ``x`` to unit RMS over its last dimension, with no learned weight.

        The math runs in float32 for stability and the result is cast back to
        the input dtype. This mirrors ``LlamaRMSNorm`` with its weight fixed
        at 1. Computing it inline avoids allocating a new norm module on every
        call, and keeps half-precision inputs in their own dtype so
        ``F.linear`` sees matching dtypes.
        """
        input_dtype = x.dtype
        x32 = x.float()
        variance = x32.pow(2).mean(dim=-1, keepdim=True)
        return (x32 * (variance + self.norm_eps).rsqrt()).to(input_dtype)

    def forward(self, x):
        """Apply the linear map to the quantized input and weights, plus the bias.

        Args:
            x: input whose last dimension equals ``in_features``. It is moved
                to the weight's device first.
        """
        w = self.weight  # shape [out_features, in_features]
        x = x.to(w.device)
        x_norm = self._rms_norm(x)
        # STE trick: the forward pass sees the quantized values, while the
        # gradient flows as if quantization were the identity.
        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        # Include the bias so layers converted from a biased nn.Linear keep it.
        return F.linear(x_quant, w_quant, self.bias)


def convert_to_bitnet(model, copy_weights):
    """Convert a Llama model to BitNet layers, in place.

    Every ``nn.Linear`` that is a direct child of an attention or MLP block
    is replaced by a ``BitLinear`` of the same shape. Each decoder layer's
    ``input_layernorm`` (a ``LlamaRMSNorm``) is replaced by ``nn.Identity``,
    because ``BitLinear`` already normalizes its own input.

    Args:
        model: a Hugging Face Llama model (e.g. ``LlamaForCausalLM``).
        copy_weights: if true, the new layers reuse the original ``weight``
            (and ``bias``) parameters. Otherwise they keep ``nn.Linear``'s
            fresh random initialization.
    """
    # Snapshot the module list first, because the loop rewires children.
    for module in list(model.modules()):
        # LlamaAttention is the common base of the eager/SDPA/flash variants,
        # so every attention implementation is converted.
        if isinstance(module, (LlamaAttention, LlamaMLP)):
            for child_name, child_module in list(module.named_children()):
                if not isinstance(child_module, nn.Linear):
                    continue
                # Create the layer on the child's own device and dtype rather
                # than a hard-coded GPU, so CPU-only and half-precision models
                # also work.
                bitlinear = BitLinear(
                    child_module.in_features,
                    child_module.out_features,
                    bias=child_module.bias is not None,
                    device=child_module.weight.device,
                    dtype=child_module.weight.dtype,
                )
                if copy_weights:
                    bitlinear.weight = child_module.weight
                    if child_module.bias is not None:
                        bitlinear.bias = child_module.bias
                setattr(module, child_name, bitlinear)
        # Remove redundant input_layernorms (BitLinear normalizes its input).
        elif isinstance(module, LlamaDecoderLayer):
            if isinstance(getattr(module, "input_layernorm", None), LlamaRMSNorm):
                module.input_layernorm = nn.Identity()


convert_to_bitnet(model, copy_weights=True)

# Use the first GPU when one is available, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device=device)

prompt = "Atherosclerosis is"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# Pass the attention mask explicitly; .get() tolerates tokenizers without one.
generate_ids = model.generate(
    inputs.input_ids,
    attention_mask=inputs.get("attention_mask"),
    max_length=50,
)
# Print the completion: a bare expression only shows up in notebooks/REPLs.
print(
    tokenizer.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
)

```