Enrico Shippole commited on
Commit
6ee4952
1 Parent(s): 82872e3
Files changed (1) hide show
  1. app.py +13 -2
app.py CHANGED
@@ -6,11 +6,22 @@ import gradio as gr
6
  def generate(prompt, seq_len=128, temperature=0.8, filter_thres=0.9):
7
  device = torch.device("cpu")
8
 
 
 
 
 
 
 
 
9
  model = PaLM(
10
- num_tokens=50304, dim=1024, depth=24, dim_head=128, heads=8, flash_attn=False, qk_rmsnorm = False,
11
  ).to(device).eval()
12
 
13
- checkpoint = torch.load('./palm_410m_8k_v0.pt', map_location=device)
 
 
 
 
14
  model.load_state_dict(checkpoint)
15
 
16
  tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
 
6
  def generate(prompt, seq_len=128, temperature=0.8, filter_thres=0.9):
7
  device = torch.device("cpu")
8
 
9
+ num_tokens = 50304
10
+ dim = 2048
11
+ depth = 16
12
+ dim_head = 128
13
+ heads = 8
14
+ flash_attn = True
15
+
16
  model = PaLM(
17
+ num_tokens=num_tokens, dim=dim, depth=depth, dim_head=dim_head, heads=heads, flash_attn=flash_attn
18
  ).to(device).eval()
19
 
20
+ # model = PaLM(
21
+ # num_tokens=50304, dim=1024, depth=24, dim_head=128, heads=8, flash_attn=False, qk_rmsnorm = False,
22
+ # ).to(device).eval()
23
+
24
+ checkpoint = torch.load('./palm_1b_8k_v0.pt', map_location=device)
25
  model.load_state_dict(checkpoint)
26
 
27
  tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")