Triton error when converting to ONNX
#114
by
hoailebads
- opened
I am having trouble converting a model that I have fine-tuned to ONNX format. The export fails with the following error:
Traceback (most recent call last):
File "/validation/convert_module/main.py", line 20, in <module>
onnx_converter.convert(onnx_path)
File "/validation/convert_module/onnx_converter.py", line 41, in convert
torch.onnx.export(
File "/usr/local/lib/python3.10/dist-packages/torch/onnx/__init__.py", line 377, in export
export(
File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 502, in export
_export(
File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 1564, in _export
graph, params_dict, torch_out = _model_to_graph(
File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 1113, in _model_to_graph
graph, params, torch_out, module = _create_jit_graph(model, args)
File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 997, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 904, in _trace_and_get_graph_from_model
trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
File "/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py", line 1500, in _get_trace_graph
outs = ONNXTracedModule(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py", line 139, in forward
graph, out = torch._C._create_graph_by_tracing(
File "/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py", line 130, in wrapper
outs.append(self.inner(*trace_inputs))
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
result = self.forward(*input, **kwargs)
File "/validation/convert_module/xlm_roberta_lora/modeling_lora.py", line 374, in forward
return self.roberta(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
result = self.forward(*input, **kwargs)
File "/validation/convert_module/xlm_roberta_lora/modeling_xlm_roberta.py", line 736, in forward
sequence_output = self.encoder(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
result = self.forward(*input, **kwargs)
File "/validation/convert_module/xlm_roberta_lora/modeling_xlm_roberta.py", line 230, in forward
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
result = self.forward(*input, **kwargs)
File "/validation/convert_module/xlm_roberta_lora/block.py", line 201, in forward
mixer_out = self.mixer(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
result = self.forward(*input, **kwargs)
File "/validation/convert_module/xlm_roberta_lora/mha.py", line 732, in forward
qkv = self.rotary_emb(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
result = self.forward(*input, **kwargs)
File "/validation/convert_module/xlm_roberta_lora/rotary.py", line 604, in forward
return apply_rotary_emb_qkv_(
File "/validation/convert_module/xlm_roberta_lora/rotary.py", line 327, in apply_rotary_emb_qkv_
return ApplyRotaryEmbQKV_.apply(
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 575, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/validation/convert_module/xlm_roberta_lora/rotary.py", line 186, in forward
apply_rotary(
File "/usr/local/lib/python3.10/dist-packages/flash_attn/ops/triton/rotary.py", line 213, in apply_rotary
rotary_kernel[grid](
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 345, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 662, in run
kernel = self.compile(
File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 276, in compile
module = src.make_ir(options, codegen_fns, context)
File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 113, in make_ir
return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)
triton.compiler.errors.CompilationError: at 34:22:
# Meta-parameters
BLOCK_K: tl.constexpr,
IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
IS_VARLEN: tl.constexpr,
INTERLEAVED: tl.constexpr,
CONJUGATE: tl.constexpr,
BLOCK_M: tl.constexpr,
):
pid_m = tl.program_id(axis=0)
pid_batch = tl.program_id(axis=1)
pid_head = tl.program_id(axis=2)
rotary_dim_half = rotary_dim // 2
^
IncompatibleTypeErrorImpl('invalid operands of type pointer<int64> and triton.language.int32')
My installed package versions:
pytorch-triton 3.0.0+dedb7bdf3
torch 2.5.0a0+e000cf0ad9.nv24.10
torch_tensorrt 2.5.0a0
torchprofile 0.0.4
torchvision 0.20.0a0
sentence-transformers 3.4.1
transformers 4.48.3
flash_attn 2.4.2
onnx 1.16.2
Please let me know if you have encountered and fixed this issue before.