Update README.md
Browse files
README.md
CHANGED
|
@@ -21,13 +21,12 @@ from huggingface_hub import snapshot_download
|
|
| 21 |
from os.path import join as pjoin
|
| 22 |
from safetensors import safe_open
|
| 23 |
|
| 24 |
-
#Simple FP4 dequant matmul
|
| 25 |
@torch.compile(fullgraph=True)
|
| 26 |
def matmul_fp4(x, W_q, scales, group_size, fp4_values):
|
| 27 |
def unpack_over_cols(W_q_packed, W_nbits, num_output_cols, dtype):
|
| 28 |
n_rows, n_cols = W_q_packed.shape
|
| 29 |
-
device
|
| 30 |
-
shifts = torch.arange(num_output_cols // n_cols, device=device, dtype=dtype) * W_nbits
|
| 31 |
W_q_unpacked = ((W_q_packed.unsqueeze(-1) >> shifts) & ((1 << W_nbits) - 1)).to(dtype)
|
| 32 |
W_q_unpacked = W_q_unpacked.view(n_rows, num_output_cols)
|
| 33 |
return W_q_unpacked
|
|
@@ -38,6 +37,7 @@ def matmul_fp4(x, W_q, scales, group_size, fp4_values):
|
|
| 38 |
return torch.matmul(x, W_r)
|
| 39 |
|
| 40 |
class AutoModelForCausalLMFP4:
|
|
|
|
| 41 |
@classmethod
|
| 42 |
def from_pretrained(
|
| 43 |
cls,
|
|
@@ -48,6 +48,7 @@ class AutoModelForCausalLMFP4:
|
|
| 48 |
*args,
|
| 49 |
**kwargs
|
| 50 |
):
|
|
|
|
| 51 |
#Download snapshot
|
| 52 |
if os.path.exists(save_dir_or_hub):
|
| 53 |
save_dir = save_dir_or_hub
|
|
|
|
| 21 |
from os.path import join as pjoin
|
| 22 |
from safetensors import safe_open
|
| 23 |
|
|
|
|
| 24 |
@torch.compile(fullgraph=True)
|
| 25 |
def matmul_fp4(x, W_q, scales, group_size, fp4_values):
|
| 26 |
def unpack_over_cols(W_q_packed, W_nbits, num_output_cols, dtype):
|
| 27 |
n_rows, n_cols = W_q_packed.shape
|
| 28 |
+
device = W_q_packed.device
|
| 29 |
+
shifts = torch.arange(num_output_cols // n_cols, device=device, dtype=W_q_packed.dtype) * W_nbits
|
| 30 |
W_q_unpacked = ((W_q_packed.unsqueeze(-1) >> shifts) & ((1 << W_nbits) - 1)).to(dtype)
|
| 31 |
W_q_unpacked = W_q_unpacked.view(n_rows, num_output_cols)
|
| 32 |
return W_q_unpacked
|
|
|
|
| 37 |
return torch.matmul(x, W_r)
|
| 38 |
|
| 39 |
class AutoModelForCausalLMFP4:
|
| 40 |
+
|
| 41 |
@classmethod
|
| 42 |
def from_pretrained(
|
| 43 |
cls,
|
|
|
|
| 48 |
*args,
|
| 49 |
**kwargs
|
| 50 |
):
|
| 51 |
+
|
| 52 |
#Download snapshot
|
| 53 |
if os.path.exists(save_dir_or_hub):
|
| 54 |
save_dir = save_dir_or_hub
|