Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --- Setup: fetch minGPT, load tokenizer/model ------------------------------
# FIX: the original used the IPython magic `!git clone ...`, which is a
# SyntaxError in a plain .py file; clone via subprocess instead (and skip the
# clone when the checkout already exists, e.g. on app restart).
import os
import subprocess
import sys

if not os.path.isdir('minGPT'):
    subprocess.run(
        ['git', 'clone', 'https://github.com/karpathy/minGPT/'],
        check=True,
    )
sys.path.append('minGPT')

from transformers import GPT2Tokenizer, GPT2LMHeadModel
from mingpt.model import GPT
# FIX: original said `from minGPT.mingpt.utils import set_seed`; with 'minGPT'
# appended to sys.path the package is importable as `mingpt`, consistent with
# the `mingpt.model` import above.
from mingpt.utils import set_seed

use_mingpt = True       # use minGPT or huggingface/transformers model?
model_type = 'gpt2-xl'
device = 'cuda'         # NOTE(review): assumes a CUDA GPU is available — confirm for the target host

if use_mingpt:
    model = GPT.from_pretrained(model_type)
else:
    model = GPT2LMHeadModel.from_pretrained(model_type)
    model.config.pad_token_id = model.config.eos_token_id  # suppress a warning

# ship model to device and set to eval mode
model.to(device)
model.eval()
|
23 |
+
|
24 |
+
|
25 |
+
def generate(prompt='', num_samples=10, steps=20, do_sample=True):
    """Sample ``num_samples`` continuations of ``prompt`` from the global ``model``.

    FIX: the original printed each sample and returned ``None``, so ``infer``
    (and therefore the gradio outputs) always received nothing. The samples
    are still printed to the console, but the decoded strings are now also
    returned as a list.

    Args:
        prompt: text to continue; '' samples unconditionally from the
            special '<|endoftext|>' start token.
        num_samples: how many continuations to draw in one batch.
        steps: number of new tokens to generate per sample.
        do_sample: sample stochastically (top-k 40) vs. greedy decoding.

    Returns:
        list[str]: the ``num_samples`` decoded continuations.
    """
    # tokenize the input prompt into integer input sequence
    tokenizer = GPT2Tokenizer.from_pretrained(model_type)
    if prompt == '':  # to create unconditional samples we feed in the special start token
        prompt = '<|endoftext|>'
    encoded_input = tokenizer(prompt, return_tensors='pt').to(device)
    x = encoded_input['input_ids']
    # replicate the single prompt row so all samples decode in one batch
    x = x.expand(num_samples, -1)

    # forward the model `steps` times to get samples, in a batch
    y = model.generate(x, max_new_tokens=steps, do_sample=do_sample, top_k=40)

    outputs = []
    for i in range(num_samples):
        out = tokenizer.decode(y[i].cpu().squeeze())
        print('-' * 80)
        print(out)
        outputs.append(out)
    return outputs
|
42 |
+
|
43 |
+
|
44 |
+
def infer(input):
    """Gradio callback: draw three 20-token continuations of the prompt."""
    samples = generate(input, num_samples=3, steps=20)
    return samples
|
46 |
+
|
47 |
+
import gradio as gr

# FIX: the original constructed the Interface and discarded it without calling
# .launch(), so running app.py did nothing. Keep a reference and launch under
# the script guard.
# NOTE(review): three "text" outputs are declared, so gradio expects infer to
# return three strings — verify against infer/generate's actual return value.
demo = gr.Interface(
    infer,
    "text",
    ["text", "text", "text"],
    examples=["I was commuting to my Silicon Valley job when I took a wrong turn. I"],
)

if __name__ == "__main__":
    demo.launch()
|