Enrico Shippole committed on
Commit
48f8345
1 Parent(s): b2e2d75

Add initial gradio setup

Browse files
Files changed (2) hide show
  1. app.py +37 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer
3
+ from palm_rlhf_pytorch import PaLM
4
+ import gradio as gr
5
+
6
# Holds the (model, tokenizer, device) triple once it has been built, so the
# checkpoint is read once per process rather than once per request.
_PALM_RESOURCES = {}


def _load_model_and_tokenizer():
    """Return the ``(model, tokenizer, device)`` triple, building it on first use."""
    if "bundle" not in _PALM_RESOURCES:
        device = torch.device("cpu")

        palm = PaLM(
            num_tokens=50304,
            dim=1024,
            depth=24,
            dim_head=128,
            heads=8,
            flash_attn=False,
            qk_rmsnorm=False,
        ).to(device)

        # NOTE(review): this is an absolute path at the filesystem root;
        # confirm the checkpoint is really placed there in the deployment.
        palm.load('/palm_410m_8k_v0.pt')

        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

        # Store everything under a single key so a concurrent reader never
        # observes a half-populated cache.
        _PALM_RESOURCES["bundle"] = (palm, tokenizer, device)

    return _PALM_RESOURCES["bundle"]


def generate(prompt, seq_len=256, temperature=0.8, filter_thres=0.9, model=None):
    """Generate a text continuation of *prompt* with the PaLM checkpoint.

    Args:
        prompt: Text to continue. An empty prompt yields an empty string
            without touching the model.
        seq_len: Forwarded to ``PaLM.generate`` as ``seq_len`` after being
            coerced to ``int`` (UI widgets may deliver floats).
        temperature: Sampling temperature forwarded to ``PaLM.generate``.
        filter_thres: Logit-filter threshold forwarded to ``PaLM.generate``.
        model: Accepted for interface compatibility but not used; the
            function always runs the single checkpoint loaded by
            ``_load_model_and_tokenizer``.

    Returns:
        The first decoded sequence. The prompt is included in the output
        because ``return_seq_without_prompt=False`` is passed.
    """
    if not prompt:
        return ""

    palm, tokenizer, device = _load_model_and_tokenizer()

    encoded_text = tokenizer(prompt, return_tensors="pt")

    output_tensor = palm.generate(
        seq_len=int(seq_len),
        prompt=encoded_text["input_ids"].to(device),
        temperature=temperature,
        filter_thres=filter_thres,
        pad_value=0.0,
        eos_token=tokenizer.eos_token_id,
        return_seq_without_prompt=False,
        use_tqdm=True,
    )

    decoded_output = tokenizer.batch_decode(output_tensor, skip_special_tokens=True)

    return decoded_output[0]
33
+
34
+
35
+
36
def _generate_from_ui(prompt, seq_len, temperature, filter_thres):
    """Adapt the UI widget values to ``generate``'s five-argument signature.

    The slider value is coerced to ``int`` before being used as a sequence
    length. ``None`` is passed for the ``model`` argument.
    """
    return generate(prompt, int(seq_len), temperature, filter_thres, None)


# One input component per argument of the adapter; a single "text" input
# would supply only the prompt and leave the remaining arguments missing.
# NOTE(review): the slider ranges and defaults below are a starting point —
# confirm them against the values the model is expected to handle.
iface = gr.Interface(
    fn=_generate_from_ui,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(minimum=1, maximum=512, value=256, step=1, label="Sequence length"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.05, label="Temperature"),
        gr.Slider(minimum=0.0, maximum=0.99, value=0.9, step=0.01, label="Filter threshold"),
    ],
    outputs="text",
)
iface.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ transformers
3
+ palm-rlhf-pytorch