Text Generation
Transformers
PyTorch
English
llama
custom_code
Inference Endpoints
text-generation-inference
Asif Ahmad committed on
Commit
b87bea2
1 Parent(s): 513a968

Update modeling_flash_llama.py

Browse files
Files changed (1) hide show
  1. modeling_flash_llama.py +1 -1
modeling_flash_llama.py CHANGED
@@ -68,7 +68,7 @@ def rmsnorm_func(hidden_states, weight, variance_epsilon):
68
  hidden_states = hidden_states.to(torch.float32)
69
  variance = hidden_states.pow(2).mean(-1, keepdim=True)
70
  hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
71
- return weight * hidden_states.to(input_dtype)
72
 
73
 
74
  class LlamaRMSNorm(nn.Module):
 
68
  hidden_states = hidden_states.to(torch.float32)
69
  variance = hidden_states.pow(2).mean(-1, keepdim=True)
70
  hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
71
+ return (weight * hidden_states).to(input_dtype)
72
 
73
 
74
  class LlamaRMSNorm(nn.Module):