mobicham committed on
Commit
eebd230
·
verified ·
1 Parent(s): 90bfd94

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +4 -3
README.md CHANGED
@@ -21,13 +21,12 @@ from huggingface_hub import snapshot_download
21
  from os.path import join as pjoin
22
  from safetensors import safe_open
23
 
24
- #Simple FP4 dequant matmul
25
  @torch.compile(fullgraph=True)
26
  def matmul_fp4(x, W_q, scales, group_size, fp4_values):
27
  def unpack_over_cols(W_q_packed, W_nbits, num_output_cols, dtype):
28
  n_rows, n_cols = W_q_packed.shape
29
- device, dtype = W_q_packed.device, W_q_packed.dtype
30
- shifts = torch.arange(num_output_cols // n_cols, device=device, dtype=dtype) * W_nbits
31
  W_q_unpacked = ((W_q_packed.unsqueeze(-1) >> shifts) & ((1 << W_nbits) - 1)).to(dtype)
32
  W_q_unpacked = W_q_unpacked.view(n_rows, num_output_cols)
33
  return W_q_unpacked
@@ -38,6 +37,7 @@ def matmul_fp4(x, W_q, scales, group_size, fp4_values):
38
  return torch.matmul(x, W_r)
39
 
40
  class AutoModelForCausalLMFP4:
 
41
  @classmethod
42
  def from_pretrained(
43
  cls,
@@ -48,6 +48,7 @@ class AutoModelForCausalLMFP4:
48
  *args,
49
  **kwargs
50
  ):
 
51
  #Download snapshot
52
  if os.path.exists(save_dir_or_hub):
53
  save_dir = save_dir_or_hub
 
21
  from os.path import join as pjoin
22
  from safetensors import safe_open
23
 
 
24
  @torch.compile(fullgraph=True)
25
  def matmul_fp4(x, W_q, scales, group_size, fp4_values):
26
  def unpack_over_cols(W_q_packed, W_nbits, num_output_cols, dtype):
27
  n_rows, n_cols = W_q_packed.shape
28
+ device = W_q_packed.device
29
+ shifts = torch.arange(num_output_cols // n_cols, device=device, dtype=W_q_packed.dtype) * W_nbits
30
  W_q_unpacked = ((W_q_packed.unsqueeze(-1) >> shifts) & ((1 << W_nbits) - 1)).to(dtype)
31
  W_q_unpacked = W_q_unpacked.view(n_rows, num_output_cols)
32
  return W_q_unpacked
 
37
  return torch.matmul(x, W_r)
38
 
39
  class AutoModelForCausalLMFP4:
40
+
41
  @classmethod
42
  def from_pretrained(
43
  cls,
 
48
  *args,
49
  **kwargs
50
  ):
51
+
52
  #Download snapshot
53
  if os.path.exists(save_dir_or_hub):
54
  save_dir = save_dir_or_hub