Error when using attn_impl triton

#27
by Wraken - opened

Hey!

I'm having problems using Triton.
I'm trying to test this model with just these few lines of code:

from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer, AutoConfig
import torch
import time

# Single place for the target device and model id, so the model and the
# prompt tensor are always moved to the same device.
DEVICE = 'cuda:0'
MODEL_NAME = 'replit/replit-code-v1-3b'

# Load the config and switch attention to the Triton flash-attention kernel
# (this is the path that reaches flash_attn_triton in the traceback).
# NOTE(review): in this thread the Triton compile step failed with
# "CUDA: Error- no device" under Ubuntu on WSL2 and worked on native Ubuntu.
# NOTE: `low_cpu_mem_usage` is a model-loading option, not a config option, so
# it is not passed here (AutoConfig.from_pretrained does not use it). If wanted,
# pass it to AutoModelForCausalLM.from_pretrained instead; recent transformers
# versions require the `accelerate` package for it — confirm for your version.
config = AutoConfig.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
)
config.attn_config['attn_impl'] = 'triton'

# Load the model, then move it to the GPU in half precision.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    config=config,
).to(device=DEVICE, dtype=torch.float16)

# Load the tokenizer.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

prompt = 'func main() {'

start_time = time.time()

inputs = tokenizer.encode(prompt, return_tensors='pt').to(device=DEVICE)
print(inputs, inputs.size(), inputs.size(dim=1))

outputs = model.generate(
    inputs,
    do_sample=True,
    max_length=100,
    top_p=0.95,
    top_k=4,
    temperature=0.2,
    num_return_sequences=4,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)

# Wall-clock time for encoding, device transfer and generation.
print(f'generation took {time.time() - start_time:.2f}s')

prompt_len = inputs.size(dim=1)
print(prompt_len)
# Each returned sequence is prompt + completion; report completion lengths.
output_lens = [len(o) - prompt_len for o in outputs]
print(output_lens)
# Decode only the generated continuation (everything after the prompt).
decoded = tokenizer.batch_decode([out[prompt_len:] for out in outputs])

for i, d in enumerate(decoded):
    print(i, d)

I'm facing an error that I can't fix; I've tried many versions of the libraries without success.
Currently I'm using: triton==2.0.0.dev20221202 and flash-attn==v1.0.3.post0

Traceback (most recent call last):
  File "<string>", line 21, in _fwd_kernel
KeyError: ('2-.-0-.-0-d6e5675c89b63c389326c8b846421ab2-2b0c5161c53c71b37ae20a9996ee4bb8-c1f92808b4e4644c1732e8338187ac87-d962222789c30252d492a16cca3bf467-12f7ac1ca211e037f62a7c0c323d9990-5c5e32ff210f3b7f56c98ca29917c25e-06f0df2d61979d629033f4a22eff5198-0dd03b0bd512a184b3512b278d9dfa59-d35ab04ae841e2714a253c523530b071', (torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float32, torch.float32, 'fp32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), ('vector', True, 128, False, False, False, 128, 128), (True, True, True, True, True, True, True, (False,), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (False, False), (True, False), (True, False), (True, False), (True, False), (True, False), (False, False), (False, False), (True, False), (True, False), (True, False), (True, False)))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/wraken/FreePilot/test.py", line 52, in <module>
    outputs = model.generate(inputs, do_sample=True, max_length=100, top_p=0.95, top_k=4, temperature=0.2, num_return_sequences=4, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id)
  File "/home/wraken/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/wraken/.local/lib/python3.10/site-packages/transformers/generation/utils.py", line 1565, in generate
    return self.sample(
  File "/home/wraken/.local/lib/python3.10/site-packages/transformers/generation/utils.py", line 2612, in sample
    outputs = self(
  File "/home/wraken/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wraken/.cache/huggingface/modules/transformers_modules/replit/replit-code-v1-3b/cecad1ade06e7f4db074893929778ae4bc9ed279/modeling_mpt.py", line 239, in forward
    outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
  File "/home/wraken/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wraken/.cache/huggingface/modules/transformers_modules/replit/replit-code-v1-3b/cecad1ade06e7f4db074893929778ae4bc9ed279/modeling_mpt.py", line 185, in forward
    (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
  File "/home/wraken/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wraken/.cache/huggingface/modules/transformers_modules/replit/replit-code-v1-3b/cecad1ade06e7f4db074893929778ae4bc9ed279/blocks.py", line 36, in forward
    (b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
  File "/home/wraken/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wraken/.cache/huggingface/modules/transformers_modules/replit/replit-code-v1-3b/cecad1ade06e7f4db074893929778ae4bc9ed279/attention.py", line 172, in forward
    (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
  File "/home/wraken/.cache/huggingface/modules/transformers_modules/replit/replit-code-v1-3b/cecad1ade06e7f4db074893929778ae4bc9ed279/attention.py", line 111, in triton_flash_attn_fn
    attn_output = flash_attn_triton.flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
  File "/home/wraken/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/wraken/.local/lib/python3.10/site-packages/flash_attn/flash_attn_triton.py", line 810, in forward
    o, lse, ctx.softmax_scale = _flash_attn_forward(
  File "/home/wraken/.local/lib/python3.10/site-packages/flash_attn/flash_attn_triton.py", line 623, in _flash_attn_forward
    _fwd_kernel[grid](
  File "/home/wraken/.local/lib/python3.10/site-packages/triton/runtime/jit.py", line 106, in launcher
    return self.run(*args, grid=grid, **kwargs)
  File "/home/wraken/.local/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 200, in run
    return self.fn.run(*args, **kwargs)
  File "<string>", line 41, in _fwd_kernel
  File "/home/wraken/.local/lib/python3.10/site-packages/triton/compiler.py", line 1256, in compile
    asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
  File "/home/wraken/.local/lib/python3.10/site-packages/triton/compiler.py", line 901, in _compile
    name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, module, device, num_warps, num_stages, extern_libs, cc)
RuntimeError: CUDA: Error- no device

I'm using Ubuntu on WSL2 and have an RTX 3080.

Does anyone know how to fix this?

Thanks :)

Okay, I think the problem was Ubuntu on WSL2; the same code works fine on a native Ubuntu system.

Wraken changed discussion status to closed

Sign up or log in to comment