sanjanatule commited on
Commit
ea6ef25
1 Parent(s): 3058282

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -7
app.py CHANGED
@@ -7,15 +7,25 @@ from utils import GPTLM
7
 
8
  newmodel = GPTLM.load_from_checkpoint('shakespeare_gpt.pth')
9
 
10
- def generate_art(character_dropdown, seed_slider):
 
 
 
 
11
 
12
- if character_dropdown == "NONE":
13
- return "NULL"
14
- else:
15
- return "NOT NULL"
16
 
17
- return "Hello There!"
18
 
 
 
 
 
 
 
 
 
 
19
 
20
  HTML_TEMPLATE = """
21
  <style>
@@ -116,7 +126,7 @@ with gr.Blocks() as interface:
116
 
117
  with gr.Row():
118
  button = gr.Button("Generate Dialogue")
119
- button.click(generate_art, inputs=inputs, outputs=outputs)
120
 
121
 
122
  if __name__ == "__main__":
 
7
 
8
  newmodel = GPTLM.load_from_checkpoint('shakespeare_gpt.pth')
9
 
10
+ chars = ['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
11
+ vocab_size = len(chars)
12
+ # create a mapping from characters to integers
13
+ stoi = { ch:i for i,ch in enumerate(chars) }
14
+ itos = { i:ch for i,ch in enumerate(chars) }
15
 
16
+ encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
17
+ decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string
 
 
18
 
 
19
 
20
+ def generate_dialogue(character_dropdown, seed_slider):
21
+
22
+ if character_dropdown == "NONE":
23
+ context = torch.zeros((1, 1), dtype=torch.long)
24
+ return decode(newmodel.model.generate(context, max_new_tokens=100)[0].tolist())
25
+ else:
26
+ context = torch.tensor([encode(character_dropdown)], dtype=torch.long)
27
+ return decode(newmodel.model.generate(context, max_new_tokens=100)[0].tolist())
28
+
29
 
30
  HTML_TEMPLATE = """
31
  <style>
 
126
 
127
  with gr.Row():
128
  button = gr.Button("Generate Dialogue")
129
+ button.click(generate_dialogue, inputs=inputs, outputs=outputs)
130
 
131
 
132
  if __name__ == "__main__":