How to actually run the model without getting run-time errors?

#4 opened by Maykeye

I first tried the example from the README:

from transformers import AutoModelForMaskedLM, BertTokenizer, pipeline

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
mlm = AutoModelForMaskedLM.from_pretrained('mosaicml/mosaic-bert-base', trust_remote_code=True,
                                           revision='24512df')  # I tried both with and without revision

classifier = pipeline('fill-mask', model=mlm, tokenizer=tokenizer)

classifier("I [MASK] to the store yesterday.")

The example does not work; it fails with:

File ~/.cache/huggingface/modules/transformers_modules/mosaicml/mosaic-bert-base/fcc434c97e2d475d5dd1a69fca9f734af7a41772/flash_attn_triton.py:781, in _flash_attn_forward(q, k, v, bias, causal, softmax_scale)
    778 assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'
    779 assert q.dtype in [torch.float16,
    780                    torch.bfloat16], 'Only support fp16 and bf16'
--> 781 assert q.is_cuda and k.is_cuda and v.is_cuda
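
In other words, the failing assert just means the Triton flash-attention kernel only runs on CUDA tensors, while the README snippet leaves the model on the CPU. A quick way to confirm where the weights end up (a sketch, nothing model-specific):

from transformers import AutoModelForMaskedLM

# Assumption: from_pretrained with no device argument keeps the weights (and hence q/k/v) on the CPU,
# which is exactly what the assert rejects.
mlm = AutoModelForMaskedLM.from_pretrained('mosaicml/mosaic-bert-base', trust_remote_code=True)
print(next(mlm.parameters()).device)  # expected: cpu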

This is trivial to fix:

mlm = AutoModelForMaskedLM.from_pretrained('mosaicml/mosaic-bert-base', trust_remote_code=True, revision='24512df').cuda()
classifier = pipeline('fill-mask', model=mlm, tokenizer=tokenizer, device="cuda:0")
classifier("I [MASK] to the store yesterday.")

And ...

KeyError                                  Traceback (most recent call last)
File <string>:21, in _fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE, IS_CAUSAL, BLOCK_HEADDIM, EVEN_M, EVEN_N, EVEN_HEADDIM, BLOCK_M, BLOCK_N, grid, num_warps, num_stages, extern_libs, stream, warmup)

KeyError: ('2-.-0-.-0-83ca8b715a9dc5f32dc1110973485f64-d6252949da17ceb5f3a278a70250af13-3b85c7bef5f0a641282f3b73af50f599-975a5a907f067e8e36a802ec0cd5bc10-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f839199e39d64851821f26-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float32, torch.float32, 'fp32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), ('matrix', False, 64, False, False, True, 128, 128), (True, True, True, True, True, True, True, (False,), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (False, False), (False, False), (False, False), (True, False), (True, False), (True, False), (False, False), (False, False), (False, False), (True, False), (True, False), (True, False), (True, False)))

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:937, in build_triton_ir(fn, signature, specialization, constants)
    936 try:
--> 937     generator.visit(fn.parse())
    938 except Exception as e:

File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node)
    854 warnings.simplefilter("ignore", PendingDeprecationWarning)  # python 3.8
--> 855 return super().visit(node)

File /usr/lib/python3.11/ast.py:418, in NodeVisitor.visit(self, node)
    417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)

File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:183, in CodeGenerator.visit_Module(self, node)
    182 def visit_Module(self, node):
--> 183     ast.NodeVisitor.generic_visit(self, node)

File /usr/lib/python3.11/ast.py:426, in NodeVisitor.generic_visit(self, node)
    425         if isinstance(item, AST):
--> 426             self.visit(item)
    427 elif isinstance(value, AST):

File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node)
    854 warnings.simplefilter("ignore", PendingDeprecationWarning)  # python 3.8
--> 855 return super().visit(node)

File /usr/lib/python3.11/ast.py:418, in NodeVisitor.visit(self, node)
    417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)

File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:252, in CodeGenerator.visit_FunctionDef(self, node)
    251 # visit function body
--> 252 has_ret = self.visit_compound_statement(node.body)
    253 # finalize function

File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:177, in CodeGenerator.visit_compound_statement(self, stmts)
    176 for stmt in stmts:
--> 177     self.last_ret_type = self.visit(stmt)
    178     if isinstance(stmt, ast.Return):

File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node)
    854 warnings.simplefilter("ignore", PendingDeprecationWarning)  # python 3.8
--> 855 return super().visit(node)

File /usr/lib/python3.11/ast.py:418, in NodeVisitor.visit(self, node)
    417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)

File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:678, in CodeGenerator.visit_For(self, node)
    677 self.scf_stack.append(node)
--> 678 self.visit_compound_statement(node.body)
    679 self.scf_stack.pop()

File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:177, in CodeGenerator.visit_compound_statement(self, stmts)
    176 for stmt in stmts:
--> 177     self.last_ret_type = self.visit(stmt)
    178     if isinstance(stmt, ast.Return):

File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node)
    854 warnings.simplefilter("ignore", PendingDeprecationWarning)  # python 3.8
--> 855 return super().visit(node)

File /usr/lib/python3.11/ast.py:418, in NodeVisitor.visit(self, node)
    417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)

File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:319, in CodeGenerator.visit_AugAssign(self, node)
    318 assign = ast.Assign(targets=[node.target], value=rhs)
--> 319 self.visit(assign)
    320 return self.get_value(name)

File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node)
    854 warnings.simplefilter("ignore", PendingDeprecationWarning)  # python 3.8
--> 855 return super().visit(node)

File /usr/lib/python3.11/ast.py:418, in NodeVisitor.visit(self, node)
    417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)

File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:301, in CodeGenerator.visit_Assign(self, node)
    300 names = _names[0]
--> 301 values = self.visit(node.value)
    302 if not isinstance(names, tuple):

File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node)
    854 warnings.simplefilter("ignore", PendingDeprecationWarning)  # python 3.8
--> 855 return super().visit(node)

File /usr/lib/python3.11/ast.py:418, in NodeVisitor.visit(self, node)
    417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)

File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:339, in CodeGenerator.visit_BinOp(self, node)
    338 lhs = self.visit(node.left)
--> 339 rhs = self.visit(node.right)
    340 fn = {
    341     ast.Add: '__add__',
    342     ast.Sub: '__sub__',
   (...)
    352     ast.BitXor: '__xor__',
    353 }[type(node.op)]

File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node)
    854 warnings.simplefilter("ignore", PendingDeprecationWarning)  # python 3.8
--> 855 return super().visit(node)

File /usr/lib/python3.11/ast.py:418, in NodeVisitor.visit(self, node)
    417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)

File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:797, in CodeGenerator.visit_Call(self, node)
    795 if (hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__)) \
    796         or impl.is_builtin(fn):
--> 797     return fn(*args, _builder=self.builder, **kws)
    798 if fn in self.builtins.values():

File ~/src/sd/sd/lib/python3.11/site-packages/triton/impl/base.py:22, in builtin.<locals>.wrapper(*args, **kwargs)
     18     raise ValueError(
     19         "Did you forget to add @triton.jit ? "
     20         "(`_builder` argument must be provided outside of JIT functions.)"
     21     )
---> 22 return fn(*args, **kwargs)

TypeError: dot() got an unexpected keyword argument 'trans_b'

The above exception was the direct cause of the following exception:

CompilationError                          Traceback (most recent call last)
Cell In[3], line 8
      4 mlm = AutoModelForMaskedLM.from_pretrained('mosaicml/mosaic-bert-base', trust_remote_code=True, revision='24512df').cuda()
      6 classifier = pipeline('fill-mask', model=mlm, tokenizer=tokenizer,device="cuda:0")
----> 8 classifier("I [MASK] to the store yesterday.")

File ~/src/sd/sd/lib/python3.11/site-packages/transformers/pipelines/fill_mask.py:239, in FillMaskPipeline.__call__(self, inputs, *args, **kwargs)
    217 def __call__(self, inputs, *args, **kwargs):
    218     """
    219     Fill the masked token in the text(s) given as inputs.
    220 
   (...)
    237         - **token_str** (`str`) -- The predicted token (to replace the masked one).
    238     """
--> 239     outputs = super().__call__(inputs, **kwargs)
    240     if isinstance(inputs, list) and len(inputs) == 1:
    241         return outputs[0]

File ~/src/sd/sd/lib/python3.11/site-packages/transformers/pipelines/base.py:1118, in Pipeline.__call__(self, inputs, num_workers, batch_size, *args, **kwargs)
   1110     return next(
   1111         iter(
   1112             self.get_iterator(
   (...)
   1115         )
   1116     )
   1117 else:
-> 1118     return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)

File ~/src/sd/sd/lib/python3.11/site-packages/transformers/pipelines/base.py:1125, in Pipeline.run_single(self, inputs, preprocess_params, forward_params, postprocess_params)
   1123 def run_single(self, inputs, preprocess_params, forward_params, postprocess_params):
   1124     model_inputs = self.preprocess(inputs, **preprocess_params)
-> 1125     model_outputs = self.forward(model_inputs, **forward_params)
   1126     outputs = self.postprocess(model_outputs, **postprocess_params)
   1127     return outputs

File ~/src/sd/sd/lib/python3.11/site-packages/transformers/pipelines/base.py:1024, in Pipeline.forward(self, model_inputs, **forward_params)
   1022     with inference_context():
   1023         model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
-> 1024         model_outputs = self._forward(model_inputs, **forward_params)
   1025         model_outputs = self._ensure_tensor_on_device(model_outputs, device=torch.device("cpu"))
   1026 else:

File ~/src/sd/sd/lib/python3.11/site-packages/transformers/pipelines/fill_mask.py:101, in FillMaskPipeline._forward(self, model_inputs)
    100 def _forward(self, model_inputs):
--> 101     model_outputs = self.model(**model_inputs)
    102     model_outputs["input_ids"] = model_inputs["input_ids"]
    103     return model_outputs

File ~/src/sd/sd/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.cache/huggingface/modules/transformers_modules/mosaicml/mosaic-bert-base/fcc434c97e2d475d5dd1a69fca9f734af7a41772/bert_layers.py:850, in BertForMaskedLM.forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, labels, output_attentions, output_hidden_states, return_dict)
    846     masked_tokens_mask = labels > 0
    848 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
--> 850 outputs = self.bert(
    851     input_ids,
    852     attention_mask=attention_mask,
    853     token_type_ids=token_type_ids,
    854     position_ids=position_ids,
    855     head_mask=head_mask,
    856     inputs_embeds=inputs_embeds,
    857     encoder_hidden_states=encoder_hidden_states,
    858     encoder_attention_mask=encoder_attention_mask,
    859     output_attentions=output_attentions,
    860     output_hidden_states=output_hidden_states,
    861     return_dict=return_dict,
    862     masked_tokens_mask=masked_tokens_mask,
    863 )
    865 sequence_output = outputs[0]
    866 prediction_scores = self.cls(sequence_output)

File ~/src/sd/sd/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.cache/huggingface/modules/transformers_modules/mosaicml/mosaic-bert-base/fcc434c97e2d475d5dd1a69fca9f734af7a41772/bert_layers.py:669, in BertModel.forward(self, input_ids, token_type_ids, attention_mask, position_ids, output_all_encoded_layers, masked_tokens_mask, **kwargs)
    666     first_col_mask[:, 0] = True
    667     subset_mask = masked_tokens_mask | first_col_mask
--> 669 encoder_outputs = self.encoder(
    670     embedding_output,
    671     attention_mask,
    672     output_all_encoded_layers=output_all_encoded_layers,
    673     subset_mask=subset_mask)
    675 if masked_tokens_mask is None:
    676     sequence_output = encoder_outputs[-1]

File ~/src/sd/sd/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.cache/huggingface/modules/transformers_modules/mosaicml/mosaic-bert-base/fcc434c97e2d475d5dd1a69fca9f734af7a41772/bert_layers.py:507, in BertEncoder.forward(self, hidden_states, attention_mask, output_all_encoded_layers, subset_mask)
    505 if subset_mask is None:
    506     for layer_module in self.layer:
--> 507         hidden_states = layer_module(hidden_states,
    508                                      cu_seqlens,
    509                                      seqlen,
    510                                      None,
    511                                      indices,
    512                                      attn_mask=attention_mask,
    513                                      bias=alibi_attn_mask)
    514         if output_all_encoded_layers:
    515             all_encoder_layers.append(hidden_states)

File ~/src/sd/sd/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.cache/huggingface/modules/transformers_modules/mosaicml/mosaic-bert-base/fcc434c97e2d475d5dd1a69fca9f734af7a41772/bert_layers.py:388, in BertLayer.forward(self, hidden_states, cu_seqlens, seqlen, subset_idx, indices, attn_mask, bias)
    366 def forward(
    367     self,
    368     hidden_states: torch.Tensor,
   (...)
    374     bias: Optional[torch.Tensor] = None,
    375 ) -> torch.Tensor:
    376     """Forward pass for a BERT layer, including both attention and MLP.
    377 
    378     Args:
   (...)
    386         bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
    387     """
--> 388     attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
    389                                       subset_idx, indices, attn_mask, bias)
    390     layer_output = self.mlp(attention_output)
    391     return layer_output

File ~/src/sd/sd/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.cache/huggingface/modules/transformers_modules/mosaicml/mosaic-bert-base/fcc434c97e2d475d5dd1a69fca9f734af7a41772/bert_layers.py:301, in BertUnpadAttention.forward(self, input_tensor, cu_seqlens, max_s, subset_idx, indices, attn_mask, bias)
    279 def forward(
    280     self,
    281     input_tensor: torch.Tensor,
   (...)
    287     bias: Optional[torch.Tensor] = None,
    288 ) -> torch.Tensor:
    289     """Forward pass for scaled self-attention without padding.
    290 
    291     Arguments:
   (...)
    299         bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
    300     """
--> 301     self_output = self.self(input_tensor, cu_seqlens, max_s, indices,
    302                             attn_mask, bias)
    303     if subset_idx is not None:
    304         return self.output(index_first_axis(self_output, subset_idx),
    305                            index_first_axis(input_tensor, subset_idx))

File ~/src/sd/sd/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.cache/huggingface/modules/transformers_modules/mosaicml/mosaic-bert-base/fcc434c97e2d475d5dd1a69fca9f734af7a41772/bert_layers.py:233, in BertUnpadSelfAttention.forward(self, hidden_states, cu_seqlens, max_seqlen_in_batch, indices, attn_mask, bias)
    231 bias_dtype = bias.dtype
    232 bias = bias.to(torch.float16)
--> 233 attention = flash_attn_qkvpacked_func(qkv, bias)
    234 attention = attention.to(orig_dtype)
    235 bias = bias.to(bias_dtype)

File ~/src/sd/sd/lib/python3.11/site-packages/torch/autograd/function.py:506, in Function.apply(cls, *args, **kwargs)
    503 if not torch._C._are_functorch_transforms_active():
    504     # See NOTE: [functorch vjp and autograd interaction]
    505     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 506     return super().apply(*args, **kwargs)  # type: ignore[misc]
    508 if cls.setup_context == _SingleLevelFunction.setup_context:
    509     raise RuntimeError(
    510         'In order to use an autograd.Function with functorch transforms '
    511         '(vmap, grad, jvp, jacrev, ...), it must override the setup_context '
    512         'staticmethod. For more details, please see '
    513         'https://pytorch.org/docs/master/notes/extending.func.html')

File ~/.cache/huggingface/modules/transformers_modules/mosaicml/mosaic-bert-base/fcc434c97e2d475d5dd1a69fca9f734af7a41772/flash_attn_triton.py:1021, in _FlashAttnQKVPackedFunc.forward(ctx, qkv, bias, causal, softmax_scale)
   1019 if qkv.stride(-1) != 1:
   1020     qkv = qkv.contiguous()
-> 1021 o, lse, ctx.softmax_scale = _flash_attn_forward(
   1022     qkv[:, :, 0],
   1023     qkv[:, :, 1],
   1024     qkv[:, :, 2],
   1025     bias=bias,
   1026     causal=causal,
   1027     softmax_scale=softmax_scale)
   1028 ctx.save_for_backward(qkv, o, lse, bias)
   1029 ctx.causal = causal

File ~/.cache/huggingface/modules/transformers_modules/mosaicml/mosaic-bert-base/fcc434c97e2d475d5dd1a69fca9f734af7a41772/flash_attn_triton.py:826, in _flash_attn_forward(q, k, v, bias, causal, softmax_scale)
    823 # BLOCK = 128
    824 # num_warps = 4 if d <= 64 else 8
    825 grid = lambda META: (triton.cdiv(seqlen_q, META['BLOCK_M']), batch * nheads)
--> 826 _fwd_kernel[grid](  # type: ignore
    827     q,
    828     k,
    829     v,
    830     bias,
    831     o,
    832     lse,
    833     tmp,
    834     softmax_scale,
    835     q.stride(0),
    836     q.stride(2),
    837     q.stride(1),
    838     k.stride(0),
    839     k.stride(2),
    840     k.stride(1),
    841     v.stride(0),
    842     v.stride(2),
    843     v.stride(1),
    844     *bias_strides,
    845     o.stride(0),
    846     o.stride(2),
    847     o.stride(1),
    848     nheads,
    849     seqlen_q,
    850     seqlen_k,
    851     seqlen_q_rounded,
    852     d,
    853     seqlen_q // 32,
    854     seqlen_k // 32,  # key for triton cache (limit number of compilations)
    855     # Can't use kwargs here because triton autotune expects key to be args, not kwargs
    856     # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
    857     bias_type,
    858     causal,
    859     BLOCK_HEADDIM,
    860     # BLOCK_M=BLOCK, BLOCK_N=BLOCK,
    861     # num_warps=num_warps,
    862     # num_stages=1,
    863 )
    864 return o, lse, softmax_scale

File ~/src/sd/sd/lib/python3.11/site-packages/triton/runtime/autotuner.py:90, in Autotuner.run(self, *args, **kwargs)
     88 if config.pre_hook is not None:
     89     config.pre_hook(self.nargs)
---> 90 return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)

File ~/src/sd/sd/lib/python3.11/site-packages/triton/runtime/autotuner.py:199, in Heuristics.run(self, *args, **kwargs)
    197 for v, heur in self.values.items():
    198     kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
--> 199 return self.fn.run(*args, **kwargs)

File <string>:41, in _fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE, IS_CAUSAL, BLOCK_HEADDIM, EVEN_M, EVEN_N, EVEN_HEADDIM, BLOCK_M, BLOCK_N, grid, num_warps, num_stages, extern_libs, stream, warmup)

File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:1621, in compile(fn, **kwargs)
   1619     next_module = parse(path)
   1620 else:
-> 1621     next_module = compile(module)
   1622     fn_cache_manager.put(next_module, f"{name}.{ir}")
   1623 if os.path.exists(path):

File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:1550, in compile.<locals>.<lambda>(src)
   1545 extern_libs = kwargs.get("extern_libs", dict())
   1546 # build compilation stages
   1547 stages = {
   1548     "ast": (lambda path: fn, None),
   1549     "ttir": (lambda path: parse_mlir_module(path, context),
-> 1550              lambda src: ast_to_ttir(src, signature, configs[0], constants)),
   1551     "ttgir": (lambda path: parse_mlir_module(path, context),
   1552               lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
   1553     "llir": (lambda path: Path(path).read_text(),
   1554              lambda src: ttgir_to_llir(src, extern_libs, capability)),
   1555     "ptx": (lambda path: Path(path).read_text(),
   1556             lambda src: llir_to_ptx(src, capability)),
   1557     "cubin": (lambda path: Path(path).read_bytes(),
   1558               lambda src: ptx_to_cubin(src, capability))
   1559 }
   1560 # find out the signature of the function
   1561 if isinstance(fn, triton.runtime.JITFunction):

File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:962, in ast_to_ttir(fn, signature, specialization, constants)
    961 def ast_to_ttir(fn, signature, specialization, constants):
--> 962     mod, _ = build_triton_ir(fn, signature, specialization, constants)
    963     return optimize_triton_ir(mod)

File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:942, in build_triton_ir(fn, signature, specialization, constants)
    940     if node is None or isinstance(e, (NotImplementedError, CompilationError)):
    941         raise e
--> 942     raise CompilationError(fn.src, node) from e
    943 ret = generator.module
    944 # module takes ownership of the context

CompilationError: at 114:24:
def _fwd_kernel(
    Q,
    K,
    V,
    Bias,
    Out,
    Lse,
    TMP,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
    softmax_scale,
    stride_qb,
    stride_qh,
    stride_qm,
    stride_kb,
    stride_kh,
    stride_kn,
    stride_vb,
    stride_vh,
    stride_vn,
    stride_bb,
    stride_bh,
    stride_bm,
    stride_ob,
    stride_oh,
    stride_om,
    nheads,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    headdim,
    CACHE_KEY_SEQLEN_Q,
    CACHE_KEY_SEQLEN_K,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # off_b = tl.program_id(1)
    # off_h = tl.program_id(2)
    # off_hb = off_b * nheads + off_h
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # Initialize pointers to Q, K, V
    # Adding parenthesis around indexing might use int32 math instead of int64 math?
    # https://github.com/openai/triton/issues/741
    # I'm seeing a tiny bit of difference (5-7us)
    q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (
        offs_m[:, None] * stride_qm + offs_d[None, :])
    k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (
        offs_n[:, None] * stride_kn + offs_d[None, :])
    v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (
        offs_n[:, None] * stride_vn + offs_d[None, :])
    if BIAS_TYPE == 'vector':
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
    elif BIAS_TYPE == 'matrix':
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (
            offs_m[:, None] * stride_bm + offs_n[None, :])
    else:
        raise ValueError("BIAS_TYPE must be one of {'vector', 'matrix'}")
    # initialize pointer to m and l
    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
    # tl.load(q_ptrs), we get the wrong output!
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
        else:
            q = tl.load(q_ptrs,
                        mask=(offs_m[:, None] < seqlen_q) &
                        (offs_d[None, :] < headdim),
                        other=0.0)
    # loop over k, v and update accumulator
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum(
        (start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn,
                            mask=offs_d[None, :] < headdim,
                            other=0.0)
        else:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn,
                            mask=(start_n + offs_n)[:, None] < seqlen_k,
                            other=0.0)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn,
                            mask=((start_n + offs_n)[:, None] < seqlen_k) &
                            (offs_d[None, :] < headdim),
                            other=0.0)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)

I'm using:

In [8]: sys.version_info
Out[8]: sys.version_info(major=3, minor=11, micro=3, releaselevel='final', serial=0)

In [9]: torch.__version__
Out[9]: '2.0.1+cu117'

In [10]: import triton

In [11]: triton.__version__
Out[11]: '2.0.0'

In [13]: import triton.language as tl

In [14]: tl.__version__
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[14], line 1
----> 1 tl.__version__

AttributeError: module 'triton.language' has no attribute '__version__'

In [15]: tl.dot
Out[15]: <function triton.language.core.dot(input, other, allow_tf32=True, _builder=None)>

Do I need a special revision= string to make it work? Did the Triton language have breaking changes?

Maykeye changed discussion title from How to actually run the model without get run time errors? to How to actually run the model without getting run-time errors?

This is a great question. Has anyone managed to successfully resolve it yet?
Being able to run the example code without errors would certainly increase confidence in the model immensely.

I've managed to run it after changing the Triton version:

pip uninstall triton
pip install --no-deps triton==2.0.0.dev20221202 

I used --no-deps because otherwise it wanted to downgrade torch from 2.0.1 to 2.0.0 (no, thank you very much).
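
After the reinstall, a quick sanity check (a sketch; that the dev build's tl.dot accepts trans_a/trans_b is my assumption, inferred from the TypeError above):

import inspect
import triton
import triton.language as tl

print(triton.__version__)         # expect '2.0.0.dev20221202'
print(inspect.signature(tl.dot))  # should list trans_a/trans_b on the dev build;
                                  # released 2.0.0 shows (input, other, allow_tf32=True, _builder=None)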

Here's a fully working example, run from the directory of the downloaded model (hence os.getcwd(); you can't use from_pretrained('.') here, as it causes weird errors down the line):

$ cat runme.py 
import os 
import torch
from transformers import AutoModelForMaskedLM, BertTokenizer, pipeline

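# bf16 weights plus .cuda() so the Triton flash-attention kernel gets CUDA bf16 tensors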
mlm = AutoModelForMaskedLM.from_pretrained(os.getcwd(), trust_remote_code=True, torch_dtype=torch.bfloat16).cuda()
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
classifier = pipeline('fill-mask', model=mlm, tokenizer=tokenizer, device=0)
print(classifier("I [MASK] to the store yesterday."))

$ python runme.py 
[{'score': 0.8977681398391724, 'token': 2253, 'token_str': 'went', 'sequence': 'i went to the store yesterday.'}, {'score': 0.02546772174537182, 'token': 2234, 'token_str': 'came', 'sequence': 'i came to the store yesterday.'}, {'score': 0.021113483235239983, 'token': 2939, 'token_str': 'walked', 'sequence': 'i walked to the store yesterday.'}, {'score': 0.013631888665258884, 'token': 2288, 'token_str': 'got', 'sequence': 'i got to the store yesterday.'}, {'score': 0.00997330341488123, 'token': 5225, 'token_str': 'drove', 'sequence': 'i drove to the store yesterday.'}]

Things I also tried:

  • replacing every tl.dot(A, B, trans_a=True) with tl.dot(tl.trans(A), B), but either I wasn't accurate enough or it's too compute-intensive: Python either hung or I lost patience.

  • throwing away flash attention and using torch's scaled_dot_product_attention. I couldn't figure out how to massage the parameters into the correct shape (a rough sketch of the reshaping is below, after this list).

  • removing the local flash_attn_triton and importing the one from the flash-attn package. It dumped a giant error log, but that's where I noticed it was using not Triton 2.0.0 but a 2.0.0 dev build.
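
For the scaled_dot_product_attention route, the reshaping would go roughly like this (a sketch only, assuming flash-attn's (batch, seqlen, 3, nheads, headdim) qkv-packed layout and an additive (batch, nheads, seqlen, seqlen) bias; I never got it running against the model):

import torch
import torch.nn.functional as F

def sdpa_from_qkvpacked(qkv: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
    # qkv: (batch, seqlen, 3, nheads, headdim) -> q, k, v each (batch, nheads, seqlen, headdim)
    q, k, v = (t.transpose(1, 2) for t in qkv.unbind(dim=2))
    # bias (e.g. the ALiBi matrix) is added to the attention scores
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
    # back to (batch, seqlen, nheads, headdim), matching flash_attn_qkvpacked_func's output layout
    return out.transpose(1, 2)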

After replacing the Triton version, everything works.

The magic version string was taken from the flash-attn Python package.

This worked beautifully! Thank you so much for sharing this solution.
Just out of curiosity, if you did not download the model and use os.getcwd(), were you actually getting errors, or was the model simply producing nonsensical inferences?
