Error when using attn_impl triton #27
by Wraken - opened
Hey!
I'm having problems using triton.
I'm trying to test this model with just these few lines of code:
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer, AutoConfig
import torch
import time
# load config
config = AutoConfig.from_pretrained(
    "replit/replit-code-v1-3b",
    trust_remote_code=True,
    low_cpu_mem_usage=True
)
config.attn_config['attn_impl'] = 'triton'
# load model
model = AutoModelForCausalLM.from_pretrained(
    'replit/replit-code-v1-3b',
    trust_remote_code=True,
    config=config,
).to(device='cuda:0', dtype=torch.float16)
# load tokenizer
tokenizer = AutoTokenizer.from_pretrained('replit/replit-code-v1-3b', trust_remote_code=True)
prompt = 'func main() {'
start_time = time.time()
inputs = tokenizer.encode(prompt, return_tensors='pt').to(device='cuda:0')
print(inputs, inputs.size(), inputs.size(dim=1))
# sample 4 completions of up to 100 tokens
outputs = model.generate(
    inputs,
    do_sample=True,
    max_length=100,
    top_p=0.95,
    top_k=4,
    temperature=0.2,
    num_return_sequences=4,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)
prompt_len = inputs.size(dim=1)
print(prompt_len)
output_lens = [len(o)-prompt_len for o in outputs]
print(output_lens)
decoded = tokenizer.batch_decode([out[prompt_len:prompt_len + g] for g,out in zip(output_lens, outputs)])
for i, d in enumerate(decoded):
    print(i, d)
I'm running into an error that I can't fix; I've tried many library versions without success.
Currently I'm using triton==2.0.0.dev20221202 and flash-attn==v1.0.3.post0.
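For reference, a minimal sanity check along these lines (just a sketch; it assumes triton and flash_attn expose __version__, hence the getattr fallback) would confirm which versions Python actually imports and whether PyTorch itself sees the GPU:

# Environment sanity check (sketch): print the imported versions and whether
# PyTorch can see the GPU before blaming the Triton kernel itself.
import torch
import triton
import flash_attn

print("torch      :", torch.__version__)
print("triton     :", getattr(triton, "__version__", "unknown"))      # assumes triton exposes __version__
print("flash_attn :", getattr(flash_attn, "__version__", "unknown"))  # assumes flash_attn exposes __version__
print("CUDA available:", torch.cuda.is_available())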
Traceback (most recent call last):
File "<string>", line 21, in _fwd_kernel
KeyError: ('2-.-0-.-0-d6e5675c89b63c389326c8b846421ab2-2b0c5161c53c71b37ae20a9996ee4bb8-c1f92808b4e4644c1732e8338187ac87-d962222789c30252d492a16cca3bf467-12f7ac1ca211e037f62a7c0c323d9990-5c5e32ff210f3b7f56c98ca29917c25e-06f0df2d61979d629033f4a22eff5198-0dd03b0bd512a184b3512b278d9dfa59-d35ab04ae841e2714a253c523530b071', (torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float32, torch.float32, 'fp32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), ('vector', True, 128, False, False, False, 128, 128), (True, True, True, True, True, True, True, (False,), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (False, False), (True, False), (True, False), (True, False), (True, False), (True, False), (False, False), (False, False), (True, False), (True, False), (True, False), (True, False)))
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/wraken/FreePilot/test.py", line 52, in <module>
outputs = model.generate(inputs, do_sample=True, max_length=100, top_p=0.95, top_k=4, temperature=0.2, num_return_sequences=4, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id)
File "/home/wraken/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/wraken/.local/lib/python3.10/site-packages/transformers/generation/utils.py", line 1565, in generate
return self.sample(
File "/home/wraken/.local/lib/python3.10/site-packages/transformers/generation/utils.py", line 2612, in sample
outputs = self(
File "/home/wraken/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/wraken/.cache/huggingface/modules/transformers_modules/replit/replit-code-v1-3b/cecad1ade06e7f4db074893929778ae4bc9ed279/modeling_mpt.py", line 239, in forward
outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
File "/home/wraken/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/wraken/.cache/huggingface/modules/transformers_modules/replit/replit-code-v1-3b/cecad1ade06e7f4db074893929778ae4bc9ed279/modeling_mpt.py", line 185, in forward
(x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
File "/home/wraken/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/wraken/.cache/huggingface/modules/transformers_modules/replit/replit-code-v1-3b/cecad1ade06e7f4db074893929778ae4bc9ed279/blocks.py", line 36, in forward
(b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
File "/home/wraken/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/wraken/.cache/huggingface/modules/transformers_modules/replit/replit-code-v1-3b/cecad1ade06e7f4db074893929778ae4bc9ed279/attention.py", line 172, in forward
(context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
File "/home/wraken/.cache/huggingface/modules/transformers_modules/replit/replit-code-v1-3b/cecad1ade06e7f4db074893929778ae4bc9ed279/attention.py", line 111, in triton_flash_attn_fn
attn_output = flash_attn_triton.flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
File "/home/wraken/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/home/wraken/.local/lib/python3.10/site-packages/flash_attn/flash_attn_triton.py", line 810, in forward
o, lse, ctx.softmax_scale = _flash_attn_forward(
File "/home/wraken/.local/lib/python3.10/site-packages/flash_attn/flash_attn_triton.py", line 623, in _flash_attn_forward
_fwd_kernel[grid](
File "/home/wraken/.local/lib/python3.10/site-packages/triton/runtime/jit.py", line 106, in launcher
return self.run(*args, grid=grid, **kwargs)
File "/home/wraken/.local/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 200, in run
return self.fn.run(*args, **kwargs)
File "<string>", line 41, in _fwd_kernel
File "/home/wraken/.local/lib/python3.10/site-packages/triton/compiler.py", line 1256, in compile
asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
File "/home/wraken/.local/lib/python3.10/site-packages/triton/compiler.py", line 901, in _compile
name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, module, device, num_warps, num_stages, extern_libs, cc)
RuntimeError: CUDA: Error- no device
I'm using Ubuntu on WSL2 and have an RTX 3080.
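Since the traceback ends in "CUDA: Error- no device", a quick visibility check like this (a sketch; the torch calls and nvidia-smi are standard, nothing model-specific is assumed) would show whether the card is reachable from inside WSL2 at all:

# GPU visibility check (sketch): is the RTX 3080 actually exposed inside WSL2?
import os
import subprocess
import torch

print("CUDA available :", torch.cuda.is_available())
print("device count   :", torch.cuda.device_count())
if torch.cuda.is_available():
    print("device name    :", torch.cuda.get_device_name(0))
print("CUDA_VISIBLE_DEVICES =", os.environ.get("CUDA_VISIBLE_DEVICES"))

# The WSL2 driver stack should also list the card here.
subprocess.run(["nvidia-smi"], check=False)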
Does anyone know how to fix this?
Thanks :)
Okay, I think the problem was Ubuntu on WSL2; it works fine on a native Ubuntu system.
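For anyone who has to stay on WSL2, one possible workaround (a sketch, assuming the model's attn_config also accepts 'torch', its default attention implementation) is to avoid the Triton kernel entirely:

# Fallback sketch: use the plain 'torch' attention implementation so no Triton
# kernel has to be compiled for the GPU. Slower than triton, but it sidesteps the error.
config.attn_config['attn_impl'] = 'torch'
model = AutoModelForCausalLM.from_pretrained(
    'replit/replit-code-v1-3b',
    trust_remote_code=True,
    config=config,
).to(device='cuda:0', dtype=torch.float16)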
Wraken changed discussion status to closed