|
from typing import Callable, Optional

import torch
from torch.library import Library

from .._ops import ops
from ..platforms import current_platform


def supports_custom_op() -> bool:
    # torch.library.custom_op is only available on sufficiently new PyTorch
    # builds (torch>=2.4, per the assertion in direct_register_custom_op).
    return hasattr(torch.library, "custom_op")


# Default library that direct_register_custom_op registers operators into.
# "FRAGMENT" adds operators to the existing namespace instead of defining a
# new one.
vllm_lib = Library(ops.__name__.split(".")[-1], "FRAGMENT")


def direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: list[str],
    fake_impl: Optional[Callable] = None,
    target_lib: Optional[Library] = None,
    dispatch_key: str = "CUDA",
):
    """
    `torch.library.custom_op` can have significant overhead because it
    needs to consider complicated dispatching logic. This function
    directly registers a custom op and dispatches it to the backend
    selected by `dispatch_key` (CUDA by default).
    See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
    for more details.

    By default, the custom op is registered to the vLLM library. If you
    want to register it to a different library, pass the library
    object via the `target_lib` argument.

    IMPORTANT: the lifetime of the operator is tied to the lifetime of the
    library object. If you bind the operator to a different library,
    make sure that library object is still alive when the operator is used.
    """
    if not supports_custom_op():
        from ..platforms import current_platform

        assert not current_platform.is_cuda_alike(), (
            "cuda platform needs torch>=2.4 to support custom op, "
            "chances are you are using an old version of pytorch "
            "or a custom build of pytorch. It is recommended to "
            "use vLLM in a fresh new environment and let it install "
            "the required dependencies."
        )
        return

    import torch.library

    if hasattr(torch.library, "infer_schema"):
        schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
    else:
        # Fallback for older PyTorch where torch.library.infer_schema is not
        # public yet.
        import torch._custom_op.impl

        schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)

    my_lib = target_lib or vllm_lib
    my_lib.define(op_name + schema_str)
    my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
    if fake_impl is not None:
        my_lib._register_fake(op_name, fake_impl)
|
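# Illustrative sketch (not used by this module): how a caller would typically
# pair an op with a fake implementation and register both through
# direct_register_custom_op. The names `_example_noop` / `_example_noop_fake`
# and the op name "example_noop" are hypothetical and only demonstrate the
# calling convention.
def _register_example_noop_op() -> None:
    """Register a trivial example op into `vllm_lib` (illustration only)."""

    def _example_noop(x: torch.Tensor) -> torch.Tensor:
        # A real op would call into a compiled kernel here.
        return x.clone()

    def _example_noop_fake(x: torch.Tensor) -> torch.Tensor:
        # Fake (meta) implementation: only shapes and dtypes matter; used by
        # torch.compile / tracing instead of running the real kernel.
        return torch.empty_like(x)

    if not supports_custom_op():
        return

    direct_register_custom_op(
        op_name="example_noop",
        op_func=_example_noop,
        mutates_args=[],
        fake_impl=_example_noop_fake,
    )
    # The op is now reachable through the torch.ops namespace of `vllm_lib`
    # and remains registered for as long as that library object is alive.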