
KeyError in triton implementation

#25
by datacow - opened

I'm loading the triton implementation of the model with a custom device map and trying to generate output as follows (to be clear, I have no issues with the torch implementation):

import pickle

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextStreamer

torch_dtype = torch.bfloat16

config = AutoConfig.from_pretrained(
    'mosaicml/mpt-7b',
    trust_remote_code=True
)
config.attn_config['attn_impl'] = 'triton'
config.update({"max_seq_len": max_len})
config.update({"torch_dtype": torch_dtype})

with open('MPT_device_map.pkl', 'rb') as f:
    dm = pickle.load(f)

model = AutoModelForCausalLM.from_pretrained(
    'mosaicml/mpt-7b-instruct',
    torch_dtype=torch_dtype,
    trust_remote_code=True,
    device_map=dm,
    config=config,
    local_files_only=True
)

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b", padding_side="left")

tokenizer.pad_token = tokenizer.eos_token

inputs = tokenizer(text, return_tensors="pt", padding=True).input_ids.to(device)

streamer = TextStreamer(tokenizer)

with torch.inference_mode():
    generate_ids = model.generate(inputs, **params, streamer=streamer)
    generate_ids = generate_ids[:,inputs[0].shape[-1]:]
    output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0]

And I'm getting the following error:

/usr/bin/ld: skipping incompatible /usr/lib/libcuda.so when searching for -lcuda
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File <string>:21, in _fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE, IS_CAUSAL, BLOCK_HEADDIM, EVEN_M, EVEN_N, EVEN_HEADDIM, BLOCK_M, BLOCK_N, grid, num_warps, num_stages, extern_libs, stream, warmup)

KeyError: ('2-.-0-.-0-d82511111ad128294e9d31a6ac684238-2b0c5161c53c71b37ae20a9996ee4bb8-c1f92808b4e4644c1732e8338187ac87-d962222789c30252d492a16cca3bf467-12f7ac1ca211e037f62a7c0c323d9990-5c5e32ff210f3b7f56c98ca29917c25e-06f0df2d61979d629033f4a22eff5198-0dd03b0bd512a184b3512b278d9dfa59-d35ab04ae841e2714a253c523530b071', (torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.float32, torch.float32, 'fp32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), ('vector', True, 128, False, False, True, 128, 128), (True, True, True, True, True, True, True, (False,), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False)))

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
Cell In[41], line 8
      4 with torch.inference_mode():
      5 #     res = nlp(text, max_new_tokens=mnt, min_new_tokens=1, return_full_text=False)
      6 #     inputs = input_map(inputs)
      7     st = time.time()
----> 8     generate_ids = model.generate(inputs, **params, streamer=streamer)
      9 #     generate_ids = model.module.generate(**inputs, **params, streamer=streamer)
     10     tt = time.time() - st

File ~/anaconda3/envs/llm/lib/python3.10/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File ~/anaconda3/envs/llm/lib/python3.10/site-packages/transformers/generation/utils.py:1565, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, **kwargs)
   1557     input_ids, model_kwargs = self._expand_inputs_for_generation(
   1558         input_ids=input_ids,
   1559         expand_size=generation_config.num_return_sequences,
   1560         is_encoder_decoder=self.config.is_encoder_decoder,
   1561         **model_kwargs,
   1562     )
   1564     # 13. run sample
-> 1565     return self.sample(
   1566         input_ids,
   1567         logits_processor=logits_processor,
   1568         logits_warper=logits_warper,
   1569         stopping_criteria=stopping_criteria,
   1570         pad_token_id=generation_config.pad_token_id,
   1571         eos_token_id=generation_config.eos_token_id,
   1572         output_scores=generation_config.output_scores,
   1573         return_dict_in_generate=generation_config.return_dict_in_generate,
   1574         synced_gpus=synced_gpus,
   1575         streamer=streamer,
   1576         **model_kwargs,
   1577     )
   1579 elif is_beam_gen_mode:
   1580     if generation_config.num_return_sequences > generation_config.num_beams:

File ~/anaconda3/envs/llm/lib/python3.10/site-packages/transformers/generation/utils.py:2612, in GenerationMixin.sample(self, input_ids, logits_processor, stopping_criteria, logits_warper, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, streamer, **model_kwargs)
   2609 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
   2611 # forward pass to get next token
-> 2612 outputs = self(
   2613     **model_inputs,
   2614     return_dict=True,
   2615     output_attentions=output_attentions,
   2616     output_hidden_states=output_hidden_states,
   2617 )
   2619 if synced_gpus and this_peer_finished:
   2620     continue  # don't waste resources running the code we don't need

File ~/anaconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/llm/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File ~/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b/d8304854d4877849c3c0a78f3469512a84419e84/modeling_mpt.py:237, in MPTForCausalLM.forward(self, input_ids, past_key_values, attention_mask, prefix_mask, sequence_id, labels, return_dict, output_attentions, output_hidden_states, use_cache)
    235 return_dict = return_dict if return_dict is not None else self.config.return_dict
    236 use_cache = use_cache if use_cache is not None else self.config.use_cache
--> 237 outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
    238 logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
    239 if self.logit_scale is not None:

File ~/anaconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b/d8304854d4877849c3c0a78f3469512a84419e84/modeling_mpt.py:183, in MPTModel.forward(self, input_ids, past_key_values, attention_mask, prefix_mask, sequence_id, return_dict, output_attentions, output_hidden_states, use_cache)
    181     all_hidden_states = all_hidden_states + (x,)
    182 past_key_value = past_key_values[b_idx] if past_key_values is not None else None
--> 183 (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
    184 if past_key_values is not None:
    185     past_key_values[b_idx] = past_key_value

File ~/anaconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/llm/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File ~/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b/d8304854d4877849c3c0a78f3469512a84419e84/blocks.py:36, in MPTBlock.forward(self, x, past_key_value, attn_bias, attention_mask, is_causal)
     34 def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
     35     a = self.norm_1(x)
---> 36     (b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
     37     x = x + self.resid_attn_dropout(b)
     38     m = self.norm_2(x)

File ~/anaconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/llm/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File ~/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b/d8304854d4877849c3c0a78f3469512a84419e84/attention.py:171, in MultiheadAttention.forward(self, x, past_key_value, attn_bias, attention_mask, is_causal, needs_weights)
    169 if attn_bias is not None:
    170     attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
--> 171 (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
    172 return (self.out_proj(context), attn_weights, past_key_value)

File ~/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b/d8304854d4877849c3c0a78f3469512a84419e84/attention.py:111, in triton_flash_attn_fn(query, key, value, n_heads, softmax_scale, attn_bias, key_padding_mask, is_causal, dropout_p, training, needs_weights, multiquery)
    109     value = value.expand(*value.shape[:2], n_heads, value.size(-1))
    110 reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
--> 111 attn_output = flash_attn_triton.flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
    112 output = attn_output.view(*attn_output.shape[:2], -1)
    113 return (output, None)

File ~/anaconda3/envs/llm/lib/python3.10/site-packages/flash_attn/flash_attn_triton.py:810, in FlashAttnFunc.forward(ctx, q, k, v, bias, causal, softmax_scale)
    808 # Make sure that the last dimension is contiguous
    809 q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
--> 810 o, lse, ctx.softmax_scale = _flash_attn_forward(
    811     q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale
    812 )
    813 ctx.save_for_backward(q, k, v, o, lse, bias)
    814 ctx.causal = causal

File ~/anaconda3/envs/llm/lib/python3.10/site-packages/flash_attn/flash_attn_triton.py:623, in _flash_attn_forward(q, k, v, bias, causal, softmax_scale)
    621 num_warps = 4 if d <= 64 else 8
    622 grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
--> 623 _fwd_kernel[grid](
    624     q, k, v, bias, o,
    625     lse, tmp,
    626     softmax_scale,
    627     q.stride(0), q.stride(2), q.stride(1),
    628     k.stride(0), k.stride(2), k.stride(1),
    629     v.stride(0), v.stride(2), v.stride(1),
    630     *bias_strides,
    631     o.stride(0), o.stride(2), o.stride(1),
    632     nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
    633     seqlen_q // 32,  seqlen_k // 32, # key for triton cache (limit number of compilations)
    634     # Can't use kwargs here because triton autotune expects key to be args, not kwargs
    635     # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
    636     bias_type, causal, BLOCK_HEADDIM,
    637     BLOCK_M=BLOCK, BLOCK_N=BLOCK,
    638     num_warps=num_warps,
    639     num_stages=1,
    640 )
    641 return o, lse, softmax_scale

File ~/anaconda3/envs/llm/lib/python3.10/site-packages/triton/runtime/jit.py:106, in KernelInterface.__getitem__.<locals>.launcher(*args, **kwargs)
    105 def launcher(*args, **kwargs):
--> 106     return self.run(*args, grid=grid, **kwargs)

File ~/anaconda3/envs/llm/lib/python3.10/site-packages/triton/runtime/autotuner.py:200, in Heuristics.run(self, *args, **kwargs)
    198 for v, heur in self.values.items():
    199     kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
--> 200 return self.fn.run(*args, **kwargs)

File <string>:43, in _fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE, IS_CAUSAL, BLOCK_HEADDIM, EVEN_M, EVEN_N, EVEN_HEADDIM, BLOCK_M, BLOCK_N, grid, num_warps, num_stages, extern_libs, stream, warmup)

RuntimeError: Triton Error [CUDA]: invalid argument

Any ideas what might be causing this? I'm working with:
Python: 3.10
CUDA: 11.7
triton: 2.0.0.dev20221202
flash-attn: 1.0.3.post0
transformers: 4.29.2
torch: 1.13.1+cu117
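
For anyone trying to reproduce, here is a minimal sketch (assuming the packages above are installed in the same environment) to print this version info programmatically:

import sys
from importlib.metadata import version, PackageNotFoundError
import torch

# Dump the interpreter, CUDA build, and package versions listed above.
print("Python:", sys.version.split()[0])
print("CUDA (torch build):", torch.version.cuda)
print("torch:", torch.__version__)
for pkg in ("triton", "flash-attn", "transformers"):
    try:
        print(f"{pkg}:", version(pkg))
    except PackageNotFoundError:
        print(f"{pkg}: not installed")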

Yeah, I'm getting this too

Update with additional context for this error: the above was on NVIDIA T4 GPUs. I switched to V100 GPUs with the same packages/installs, and am now getting something different:

Briefly explain to me what the Reimann Hypothesis 
/usr/bin/ld: skipping incompatible /usr/lib/libcuda.so when searching for -lcuda
is.
â-s�AN’s,
  
/usr/bin/ld: skipping incompatible /usr/lib/libcuda.so when searching for -lcuda
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[46], line 8
      4 with torch.inference_mode():
      5 #     res = nlp(text, max_new_tokens=mnt, min_new_tokens=1, return_full_text=False)
      6 #     inputs = input_map(inputs)
      7     st = time.time()
----> 8     generate_ids = model.generate(inputs, **params, streamer=streamer)
      9 #     generate_ids = model.module.generate(**inputs, **params, streamer=streamer)
     10     tt = time.time() - st

File ~/anaconda3/envs/llm/lib/python3.10/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File ~/anaconda3/envs/llm/lib/python3.10/site-packages/transformers/generation/utils.py:1565, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, **kwargs)
   1557     input_ids, model_kwargs = self._expand_inputs_for_generation(
   1558         input_ids=input_ids,
   1559         expand_size=generation_config.num_return_sequences,
   1560         is_encoder_decoder=self.config.is_encoder_decoder,
   1561         **model_kwargs,
   1562     )
   1564     # 13. run sample
-> 1565     return self.sample(
   1566         input_ids,
   1567         logits_processor=logits_processor,
   1568         logits_warper=logits_warper,
   1569         stopping_criteria=stopping_criteria,
   1570         pad_token_id=generation_config.pad_token_id,
   1571         eos_token_id=generation_config.eos_token_id,
   1572         output_scores=generation_config.output_scores,
   1573         return_dict_in_generate=generation_config.return_dict_in_generate,
   1574         synced_gpus=synced_gpus,
   1575         streamer=streamer,
   1576         **model_kwargs,
   1577     )
   1579 elif is_beam_gen_mode:
   1580     if generation_config.num_return_sequences > generation_config.num_beams:

File ~/anaconda3/envs/llm/lib/python3.10/site-packages/transformers/generation/utils.py:2648, in GenerationMixin.sample(self, input_ids, logits_processor, stopping_criteria, logits_warper, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, streamer, **model_kwargs)
   2646 # sample
   2647 probs = nn.functional.softmax(next_token_scores, dim=-1)
-> 2648 next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
   2650 # finished sentences should have their next token be a padding token
   2651 if eos_token_id is not None:

RuntimeError: probability tensor contains either `inf`, `nan` or element < 0

Maybe this provides more context? At the beginning you can see the prompt given: "Briefly explain to me what the Reimann Hypothesis is." It spits out a gibberish token, then fails. Any thoughts? @sam-mosaic
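
A quick way to narrow this down (a minimal sketch, reusing the model and inputs from the first snippet) is to run a single forward pass and check the logits for inf/nan before sampling:

# Sketch: inspect the raw logits that feed torch.multinomial in sample().
with torch.inference_mode():
    logits = model(inputs).logits
print("any nan:", torch.isnan(logits).any().item())
print("any inf:", torch.isinf(logits).any().item())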

Mosaic ML, Inc. org

I'm not sure if this is the root cause, but we just added device_map support recently: https://huggingface.co/mosaicml/mpt-7b-instruct/discussions/41
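
With that change, loading with an automatic device map should look roughly like this (a sketch, assuming the config object from the snippet at the top of the thread):

# Sketch: rely on accelerate's automatic placement instead of the pickled map.
model = AutoModelForCausalLM.from_pretrained(
    'mosaicml/mpt-7b-instruct',
    config=config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map='auto',
)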

@abhi-mosaic I'm still seeing the same issue (RuntimeError: Triton Error [CUDA]: invalid argument) with the auto device map. Relevant packages/installs:

Python 3.10
CUDA 11.7
4x16GB T4 GPUs
einops==0.5.0
torch==1.13.1
transformers==4.29.2
triton-pre-mlir @ git+https://github.com/vchiley/triton.git@48b1cc9ff8b1f506ac32f2124471e2582875c008#subdirectory=python

Also, testing again on 4x16GB V100 GPUs with the same installs, I get another error as noted above: "RuntimeError: probability tensor contains either inf, nan or element < 0"

Mosaic ML, Inc. org

We have only tested triton on A10s and A100s. It may not be an option on T4s or V100s.

sam-mosaic changed discussion status to closed
sam-mosaic changed discussion status to open

@sam-mosaic is bfloat16 precision required for the triton implementation? T4s and V100s don't support bfloat16, but I've tried regular float16 as well and get the same error. So if the triton implementation can't run in float16, the lack of bfloat16 support on those GPUs would explain this issue.
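
One way to check this on a given machine (a sketch using standard PyTorch queries) is to print each GPU's compute capability; native bfloat16 needs compute capability 8.0+ (Ampere), while T4s are sm_75 and V100s are sm_70:

import torch

# Native bfloat16 support requires compute capability >= 8.0 (Ampere or newer).
for i in range(torch.cuda.device_count()):
    major, minor = torch.cuda.get_device_capability(i)
    name = torch.cuda.get_device_name(i)
    print(f"{name}: sm_{major}{minor}, native bf16: {major >= 8}")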
