add kernel file check in modeling_qwen.py
modeling_qwen.py (+14 -4)
@@ -6,11 +6,13 @@
 import copy
 import importlib
 import math
+import pathlib
 from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
 
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
+import warnings
 from torch.cuda.amp import autocast
 
 from torch.nn import CrossEntropyLoss
@@ -295,11 +297,19 @@ class QWenAttention(nn.Module):
         self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)
 
         if config.use_cache_quantization and config.use_cache_kernel:
-            try:
-                from .cpp_kernels import cache_autogptq_cuda_256
-                self.cache_kernels = cache_autogptq_cuda_256
-            except ImportError:
+            # pre-check that the kernel support files exist
+            module_root = pathlib.Path(__file__).parent
+            src_files = ("cache_autogptq_cuda_256.cpp", "cache_autogptq_cuda_kernel_256.cu")
+            if any(not (module_root / src).is_file() for src in src_files):
+                warnings.warn("KV cache kernel source files (.cpp and .cu) not found.")
                 self.cache_kernels = None
+            else:
+                try:
+                    from .cpp_kernels import cache_autogptq_cuda_256
+                    self.cache_kernels = cache_autogptq_cuda_256
+                except ImportError:
+                    warnings.warn("Failed to import KV cache kernels.")
+                    self.cache_kernels = None
 
     def _attn(self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None):
         device = query.device
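For context, assuming (as in Qwen's layout) that cpp_kernels JIT-compiles the .cpp/.cu sources at import time, checking for those files first lets the model degrade to the non-kernel cache path with a readable warning instead of a build error. Below is a minimal standalone sketch of the same two-stage guard; the load_cache_kernels helper and the flat cpp_kernels import path are hypothetical (the model itself uses the relative "from .cpp_kernels import ..." shown in the diff).

import pathlib
import warnings

def load_cache_kernels(module_root: pathlib.Path):
    # Guard 1: the .cpp/.cu sources must ship alongside the module;
    # if they are missing, skip the import instead of letting the
    # JIT build fail with a harder-to-read error.
    src_files = ("cache_autogptq_cuda_256.cpp", "cache_autogptq_cuda_kernel_256.cu")
    if any(not (module_root / src).is_file() for src in src_files):
        warnings.warn("KV cache kernel source files (.cpp and .cu) not found.")
        return None
    # Guard 2: even with the sources present, compilation or import can
    # still fail (e.g. no CUDA toolchain), so fall back to None.
    try:
        from cpp_kernels import cache_autogptq_cuda_256  # hypothetical flat-layout import
        return cache_autogptq_cuda_256
    except ImportError:
        warnings.warn("Failed to import KV cache kernels.")
        return None

# Usage sketch: callers branch on availability, e.g.
#   kernels = load_cache_kernels(pathlib.Path(__file__).parent)
#   if kernels is None: use the pure-PyTorch quantized-cache path instead.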