ToletiSri commited on
Commit
7a202bc
1 Parent(s): a59382e

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +47 -0
  2. saved_model.pth +1 -1
app.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from gpt import GPTLanguageModel
3
+ import torch
4
+ import config as cfg
5
+
6
+ torch.manual_seed(1337)
7
+
8
+ # wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
9
+ with open('input.txt', 'r', encoding='utf-8') as f:
10
+ text = f.read()
11
+
12
+ # here are all the unique characters that occur in this text
13
+ chars = sorted(list(set(text)))
14
+ vocab_size = len(chars)
15
+ # create a mapping from characters to integers
16
+ stoi = { ch:i for i,ch in enumerate(chars) }
17
+ itos = { i:ch for i,ch in enumerate(chars) }
18
+ encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
19
+ decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string
20
+
21
+ model = GPTLanguageModel(vocab_size)
22
+ model.load_state_dict(torch.load('saved_model.pth'))
23
+ m = model.to(cfg.device)
24
+
25
+ def inference(input_context, count):
26
+ encoded_text = [encode(input_context)]
27
+ count = int(count)
28
+ context = torch.tensor(encoded_text, dtype=torch.long, device=cfg.device)
29
+
30
+ print('--------------------context = ',context)
31
+ out_text = decode(m.generate(context, max_new_tokens=count)[0].tolist())
32
+ return out_text
33
+
34
+ title = "TSAI S21 Assignment: GPT training on mini shakespeare dataset"
35
+ description = "A simple Gradio interface that accepts a context and generates shakespere like text "
36
+
37
+
38
+ demo = gr.Interface(
39
+ inference,
40
+ inputs = [gr.Textbox(placeholder="Enter starting characters"), gr.Textbox(placeholder="Enter number of characters you want to generate")],
41
+ outputs = [gr.Textbox(label="Shakespeare like generated text")],
42
+ title = title,
43
+ description = description
44
+ )
45
+
46
+ demo.launch()
47
+
saved_model.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b4c2d92904deec61a54aea4d8507baa2963b113d8a4a29f2ccb5d5face8b2f03
3
  size 52672301
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d87d03e3d04ece587d28b1b383a06a4420d6e13d61c2f0e7b4682c9b40f93855
3
  size 52672301