Enrico Shippole commited on
Commit
8253e02
1 Parent(s): 6ee4952
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -13,14 +13,14 @@ def generate(prompt, seq_len=128, temperature=0.8, filter_thres=0.9):
13
  heads = 8
14
  flash_attn = True
15
 
16
- model = PaLM(
17
- num_tokens=num_tokens, dim=dim, depth=depth, dim_head=dim_head, heads=heads, flash_attn=flash_attn
18
- ).to(device).eval()
19
-
20
  # model = PaLM(
21
- # num_tokens=50304, dim=1024, depth=24, dim_head=128, heads=8, flash_attn=False, qk_rmsnorm = False,
22
  # ).to(device).eval()
23
 
 
 
 
 
24
  checkpoint = torch.load('./palm_1b_8k_v0.pt', map_location=device)
25
  model.load_state_dict(checkpoint)
26
 
 
13
  heads = 8
14
  flash_attn = True
15
 
 
 
 
 
16
  # model = PaLM(
17
+ # num_tokens=num_tokens, dim=dim, depth=depth, dim_head=dim_head, heads=heads, flash_attn=flash_attn
18
  # ).to(device).eval()
19
 
20
+ model = PaLM(
21
+ num_tokens=50304, dim=1024, depth=24, dim_head=128, heads=8, flash_attn=False, qk_rmsnorm = False,
22
+ ).to(device).eval()
23
+
24
  checkpoint = torch.load('./palm_1b_8k_v0.pt', map_location=device)
25
  model.load_state_dict(checkpoint)
26