Triton version

#23
by JiayiJennie - opened

May I know which Triton version is required to run the model? I am using triton 2.2.0 here.
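
For completeness, I am following the model card demo, roughly like this (the loading lines below are reconstructed from the demo and should be treated as an assumption; only the forward call appears verbatim in my notebook):

# Loading lines assumed from the DNABERT-S model card demo; trust_remote_code pulls in
# the custom bert_layers.py / flash_attn_triton.py modules that show up in the traceback.
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-S", trust_remote_code=True)
model = AutoModel.from_pretrained("zhihan1996/DNABERT-S", trust_remote_code=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

dna = "ACGTAGCATCGGATCTATCTATCGAC"
inputs = tokenizer(dna, return_tensors="pt")["input_ids"].to(device)

hidden_states = model(inputs)[0]  # [1, sequence_length, 768]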

When I pass the input to the model as the demo code shows, hidden_states = model(inputs)[0], the following error is raised:

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py:1228, in ast_to_ttir(fn, signature, specialization, constants, debug, target)
1227 try:
-> 1228 generator.visit(fn.parse())
1229 except CompilationError as e:

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py:1105, in CodeGenerator.visit(self, node)
1104 last_loc = self.builder.get_loc()
-> 1105 ret = super().visit(node)
1106 # Reset the location to the last one before the visit

File ~/miniconda3/envs/DNABERT/lib/python3.10/ast.py:418, in NodeVisitor.visit(self, node)
417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py:303, in CodeGenerator.visit_Module(self, node)
302 def visit_Module(self, node):
--> 303 ast.NodeVisitor.generic_visit(self, node)

File ~/miniconda3/envs/DNABERT/lib/python3.10/ast.py:426, in NodeVisitor.generic_visit(self, node)
425 if isinstance(item, AST):
--> 426 self.visit(item)
427 elif isinstance(value, AST):

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py:1105, in CodeGenerator.visit(self, node)
1104 last_loc = self.builder.get_loc()
-> 1105 ret = super().visit(node)
1106 # Reset the location to the last one before the visit

File ~/miniconda3/envs/DNABERT/lib/python3.10/ast.py:418, in NodeVisitor.visit(self, node)
417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py:376, in CodeGenerator.visit_FunctionDef(self, node)
375 # visit function body
--> 376 self.visit_compound_statement(node.body)
377 # finalize function

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py:298, in CodeGenerator.visit_compound_statement(self, stmts)
297 for stmt in stmts:
--> 298 ret_type = self.visit(stmt)
299 if ret_type is not None and isinstance(stmt, ast.Return):

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py:1105, in CodeGenerator.visit(self, node)
1104 last_loc = self.builder.get_loc()
-> 1105 ret = super().visit(node)
1106 # Reset the location to the last one before the visit

File ~/miniconda3/envs/DNABERT/lib/python3.10/ast.py:418, in NodeVisitor.visit(self, node)
417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py:885, in CodeGenerator.visit_For(self, node)
884 self.scf_stack.append(node)
--> 885 self.visit_compound_statement(node.body)
886 self.scf_stack.pop()

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py:298, in CodeGenerator.visit_compound_statement(self, stmts)
297 for stmt in stmts:
--> 298 ret_type = self.visit(stmt)
299 if ret_type is not None and isinstance(stmt, ast.Return):

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py:1105, in CodeGenerator.visit(self, node)
1104 last_loc = self.builder.get_loc()
-> 1105 ret = super().visit(node)
1106 # Reset the location to the last one before the visit

File ~/miniconda3/envs/DNABERT/lib/python3.10/ast.py:418, in NodeVisitor.visit(self, node)
417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py:448, in CodeGenerator.visit_AugAssign(self, node)
447 assign = ast.Assign(targets=[node.target], value=rhs)
--> 448 self.visit(assign)
449 return self.dereference_name(name)

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py:1105, in CodeGenerator.visit(self, node)
1104 last_loc = self.builder.get_loc()
-> 1105 ret = super().visit(node)
1106 # Reset the location to the last one before the visit

File ~/miniconda3/envs/DNABERT/lib/python3.10/ast.py:418, in NodeVisitor.visit(self, node)
417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py:428, in CodeGenerator.visit_Assign(self, node)
427 names = _names[0]
--> 428 values = self.visit(node.value)
429 if not _is_list_like(names):

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py:1105, in CodeGenerator.visit(self, node)
1104 last_loc = self.builder.get_loc()
-> 1105 ret = super().visit(node)
1106 # Reset the location to the last one before the visit

File ~/miniconda3/envs/DNABERT/lib/python3.10/ast.py:418, in NodeVisitor.visit(self, node)
417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py:477, in CodeGenerator.visit_BinOp(self, node)
476 lhs = self.visit(node.left)
--> 477 rhs = self.visit(node.right)
478 method_name = self._method_name_for_bin_op.get(type(node.op))

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py:1105, in CodeGenerator.visit(self, node)
1104 last_loc = self.builder.get_loc()
-> 1105 ret = super().visit(node)
1106 # Reset the location to the last one before the visit

File ~/miniconda3/envs/DNABERT/lib/python3.10/ast.py:418, in NodeVisitor.visit(self, node)
417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py:1027, in CodeGenerator.visit_Call(self, node)
1026 extra_kwargs['_generator'] = self
-> 1027 return fn(*args, **extra_kwargs, **kws)
1028 if fn in self.builtin_namespace.values():

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/triton/language/core.py:27, in builtin.<locals>.wrapper(*args, **kwargs)
25 raise ValueError("Did you forget to add @triton.jit ? "
26 "(_builder argument must be provided outside of JIT functions.)")
---> 27 return fn(*args, **kwargs)

TypeError: dot() got an unexpected keyword argument 'trans_b'

The above exception was the direct cause of the following exception:

CompilationError Traceback (most recent call last)
Cell In[2], line 8
5 model.to(device)
6 inputs = inputs.to(device)
----> 8 hidden_states = model(inputs)[0] # [1, sequence_length, 768]
10 # embedding with mean pooling
11 # embedding_mean = torch.mean(hidden_states[0], dim=0)
12 # print(embedding_mean.shape) # expect to be 768

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None

File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-S/1cdf84d992ace6f3e75c7356774b4da088c8dc7c/bert_layers.py:608, in BertModel.forward(self, input_ids, token_type_ids, attention_mask, position_ids, output_all_encoded_layers, masked_tokens_mask, **kwargs)
605 first_col_mask[:, 0] = True
606 subset_mask = masked_tokens_mask | first_col_mask
--> 608 encoder_outputs = self.encoder(
609 embedding_output,
610 attention_mask,
611 output_all_encoded_layers=output_all_encoded_layers,
612 subset_mask=subset_mask)
614 if masked_tokens_mask is None:
615 sequence_output = encoder_outputs[-1]

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None

File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-S/1cdf84d992ace6f3e75c7356774b4da088c8dc7c/bert_layers.py:446, in BertEncoder.forward(self, hidden_states, attention_mask, output_all_encoded_layers, subset_mask)
444 if subset_mask is None:
445 for layer_module in self.layer:
--> 446 hidden_states = layer_module(hidden_states,
447 cu_seqlens,
448 seqlen,
449 None,
450 indices,
451 attn_mask=attention_mask,
452 bias=alibi_attn_mask)
453 if output_all_encoded_layers:
454 all_encoder_layers.append(hidden_states)

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None

File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-S/1cdf84d992ace6f3e75c7356774b4da088c8dc7c/bert_layers.py:327, in BertLayer.forward(self, hidden_states, cu_seqlens, seqlen, subset_idx, indices, attn_mask, bias)
305 def forward(
306 self,
307 hidden_states: torch.Tensor,
(...)
313 bias: Optional[torch.Tensor] = None,
314 ) -> torch.Tensor:
315 """Forward pass for a BERT layer, including both attention and MLP.
316
317 Args:
(...)
325 bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
326 """
--> 327 attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
328 subset_idx, indices, attn_mask, bias)
329 layer_output = self.mlp(attention_output)
330 return layer_output

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None

File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-S/1cdf84d992ace6f3e75c7356774b4da088c8dc7c/bert_layers.py:240, in BertUnpadAttention.forward(self, input_tensor, cu_seqlens, max_s, subset_idx, indices, attn_mask, bias)
218 def forward(
219 self,
220 input_tensor: torch.Tensor,
(...)
226 bias: Optional[torch.Tensor] = None,
227 ) -> torch.Tensor:
228 """Forward pass for scaled self-attention without padding.
229
230 Arguments:
(...)
238 bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
239 """
--> 240 self_output = self.self(input_tensor, cu_seqlens, max_s, indices,
241 attn_mask, bias)
242 if subset_idx is not None:
243 return self.output(index_first_axis(self_output, subset_idx),
244 index_first_axis(input_tensor, subset_idx))

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None

File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-S/1cdf84d992ace6f3e75c7356774b4da088c8dc7c/bert_layers.py:181, in BertUnpadSelfAttention.forward(self, hidden_states, cu_seqlens, max_seqlen_in_batch, indices, attn_mask, bias)
179 bias_dtype = bias.dtype
180 bias = bias.to(torch.float16)
--> 181 attention = flash_attn_qkvpacked_func(qkv, bias)
182 attention = attention.to(orig_dtype)
183 bias = bias.to(bias_dtype)

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/torch/autograd/function.py:553, in Function.apply(cls, *args, **kwargs)
550 if not torch._C._are_functorch_transforms_active():
551 # See NOTE: [functorch vjp and autograd interaction]
552 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 553 return super().apply(*args, **kwargs) # type: ignore[misc]
555 if not is_setup_ctx_defined:
556 raise RuntimeError(
557 "In order to use an autograd.Function with functorch transforms "
558 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
559 "staticmethod. For more details, please see "
560 "https://pytorch.org/docs/master/notes/extending.func.html"
561 )

File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-S/1cdf84d992ace6f3e75c7356774b4da088c8dc7c/flash_attn_triton.py:1021, in _FlashAttnQKVPackedFunc.forward(ctx, qkv, bias, causal, softmax_scale)
1019 if qkv.stride(-1) != 1:
1020 qkv = qkv.contiguous()
-> 1021 o, lse, ctx.softmax_scale = _flash_attn_forward(
1022 qkv[:, :, 0],
1023 qkv[:, :, 1],
1024 qkv[:, :, 2],
1025 bias=bias,
1026 causal=causal,
1027 softmax_scale=softmax_scale)
1028 ctx.save_for_backward(qkv, o, lse, bias)
1029 ctx.causal = causal

File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-S/1cdf84d992ace6f3e75c7356774b4da088c8dc7c/flash_attn_triton.py:826, in _flash_attn_forward(q, k, v, bias, causal, softmax_scale)
823 # BLOCK = 128
824 # num_warps = 4 if d <= 64 else 8
825 grid = lambda META: (triton.cdiv(seqlen_q, META['BLOCK_M']), batch * nheads)
--> 826 _fwd_kernel[grid]( # type: ignore
827 q,
828 k,
829 v,
830 bias,
831 o,
832 lse,
833 tmp,
834 softmax_scale,
835 q.stride(0),
836 q.stride(2),
837 q.stride(1),
838 k.stride(0),
839 k.stride(2),
840 k.stride(1),
841 v.stride(0),
842 v.stride(2),
843 v.stride(1),
844 *bias_strides,
845 o.stride(0),
846 o.stride(2),
847 o.stride(1),
848 nheads,
849 seqlen_q,
850 seqlen_k,
851 seqlen_q_rounded,
852 d,
853 seqlen_q // 32,
854 seqlen_k // 32, # key for triton cache (limit number of compilations)
855 # Can't use kwargs here because triton autotune expects key to be args, not kwargs
856 # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
857 bias_type,
858 causal,
859 BLOCK_HEADDIM,
860 # BLOCK_M=BLOCK, BLOCK_N=BLOCK,
861 # num_warps=num_warps,
862 # num_stages=1,
863 )
864 return o, lse, softmax_scale

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py:156, in Autotuner.run(self, *args, **kwargs)
154 if config.pre_hook is not None:
155 config.pre_hook(full_nargs)
--> 156 ret = self.fn.run(
157 *args,
158 num_warps=config.num_warps,
159 num_stages=config.num_stages,
160 num_ctas=config.num_ctas,
161 enable_warp_specialization=config.enable_warp_specialization,
162 **kwargs,
163 **config.kwargs,
164 )
165 self.nargs = None
166 return ret

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py:305, in Heuristics.run(self, *args, **kwargs)
303 for v, heur in self.values.items():
304 kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
--> 305 return self.fn.run(*args, **kwargs)

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/triton/runtime/jit.py:532, in JITFunction.run(self, *args, **kwargs)
517 if self._call_hook(
518 key,
519 signature,
(...)
528 configs,
529 ):
530 return None
--> 532 self.cache[device][key] = compile(
533 self,
534 signature=signature,
535 device=device,
536 constants=constants,
537 num_warps=num_warps,
538 num_ctas=num_ctas,
539 num_stages=num_stages,
540 enable_warp_specialization=enable_warp_specialization,
541 enable_fp_fusion=enable_fp_fusion,
542 extern_libs=extern_libs,
543 configs=configs,
544 debug=self.debug,
545 device_type=device_type,
546 )
548 bin = self.cache[device][key]
549 if not warmup:

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/triton/compiler/compiler.py:543, in compile(fn, **kwargs)
541 path = metadata_group.get(ir_filename)
542 if path is None:
--> 543 next_module = compile_kernel(module)
544 if ir_name == "amdgcn":
545 extra_file_name = f"{name}.hsaco_path"

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/triton/compiler/compiler.py:435, in compile.<locals>.<lambda>(src)
432 stages = dict()
433 stages["ast"] = (lambda path: fn, None)
434 stages["ttir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttir(
--> 435 ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target))
436 if is_cuda:
437 stages["ttgir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttgir(
438 ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info,
439 enable_warp_specialization, enable_persistent, optimize_epilogue))

File /opt/local/stow/pip-3.10/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py:1237, in ast_to_ttir(fn, signature, specialization, constants, debug, target)
1235 if node is None:
1236 raise
-> 1237 raise CompilationError(fn.src, node, repr(e)) from e
1238 ret = generator.module
1239 # module takes ownership of the context

CompilationError: at 114:24: else:
if EVEN_HEADDIM:
k = tl.load(k_ptrs + start_n * stride_kn,
mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0)
else:
k = tl.load(k_ptrs + start_n * stride_kn,
mask=((start_n + offs_n)[:, None] < seqlen_k) &
(offs_d[None, :] < headdim),
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k, trans_b=True)
^
TypeError("dot() got an unexpected keyword argument 'trans_b'")

Solved the problem by pinning triton==2.0.0.dev20221202.
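
For reference, roughly what that looks like (the pip command assumes the dev wheel is still available from your package index):

# The bundled flash_attn_triton.py kernel calls tl.dot(q, k, trans_b=True); the trans_b
# keyword was removed in newer Triton releases, which is what triggers the TypeError above.
# In a shell, downgrade Triton first:
#   pip uninstall -y triton
#   pip install triton==2.0.0.dev20221202

import triton

print(triton.__version__)  # confirm the downgrade actually took effect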

I am hitting a similar error,

"TypeError("dot() got an unexpected keyword argument 'trans_b'")"

even after pinning triton==2.0.0.dev20221202, when I run:

dna = "ACGTAGCATCGGATCTATCTATCGAC"

if torch.cuda.is_available():
    model.to('cuda')
    inputs = tokenizer(dna, return_tensors='pt')['input_ids'].to('cuda')
else:
    inputs = tokenizer(dna, return_tensors='pt')['input_ids']

hidden_states = model(inputs)[0] # [1, sequence_length, 768]
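
In case the pinned version is not the one that actually gets imported (the traceback above mixes packages under /opt/local/.../dist-packages with a ~/miniconda3/envs/DNABERT environment, so a system-wide Triton can shadow the one installed in the conda env), this quick check shows which build is loaded at runtime:

# Print which Python interpreter and which Triton installation are in use; a stray
# site-packages copy of Triton can shadow the version pinned inside the conda env.
import sys
import triton

print("python      :", sys.executable)
print("triton      :", triton.__version__)
print("loaded from :", triton.__file__)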

I have run into the same problem as @YiXW. Were you able to solve it?

