saichandrapandraju committed on
Commit
3e4ffd8
1 Parent(s): 275ba7e

add count based and single layer nn

Browse files
Files changed (6) hide show
  1. app.py +82 -0
  2. count_probs.pt +0 -0
  3. ctoi.json +1 -0
  4. itoc.json +1 -0
  5. requirements.txt +3 -0
  6. single_layer.pt +0 -0
app.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import json
4
+ import torch.nn.functional as F
5
+ import pandas as pd
6
+
7
# Fixed RNG seed: each prediction function builds a fresh generator from this
# value, so the same starting character always yields the same names.
SEED = 42
8
+
9
@st.cache_resource
def init_count_model():
    """Load the count-based model's probability table from ``count_probs.pt``.

    Wrapped in ``st.cache_resource`` so the file does not have to be re-read
    on every Streamlit rerun.

    Returns:
        The object stored in ``count_probs.pt``. The sampling code indexes it
        by character id and feeds the selected row to ``torch.multinomial``
        (presumably a 27x27 tensor of row-normalised bigram probabilities --
        confirm against the training code).
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that were produced by this project.
    # map_location="cpu": sampling draws with a CPU torch.Generator, so the
    # table must live on the CPU even if it was saved from a GPU session.
    return torch.load("count_probs.pt", map_location="cpu")
12
+
13
@st.cache_resource
def init_single_layer_model():
    """Load the single-linear-layer model's weights from ``single_layer.pt``.

    Wrapped in ``st.cache_resource`` so the file does not have to be re-read
    on every Streamlit rerun.

    Returns:
        The object stored in ``single_layer.pt``. The sampling code
        right-multiplies a one-hot row vector by it, so it is expected to be
        a 2-D weight tensor (presumably 27x27 -- confirm against the
        training code).
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that were produced by this project.
    # map_location="cpu": the weights are multiplied with a CPU one-hot
    # tensor and sampled with a CPU torch.Generator, so they must be on the
    # CPU even if they were saved from a GPU session.
    return torch.load("single_layer.pt", map_location="cpu")
16
+
17
@st.cache_resource
def init_char_index_mappings():
    """Read the character<->index lookup tables from disk.

    Returns:
        A ``(ctoi, itoc)`` tuple: the parsed contents of ``ctoi.json``
        followed by the parsed contents of ``itoc.json``. Because the data
        comes from JSON, the keys of ``itoc`` are strings.
    """
    with open("ctoi.json") as ctoi_file:
        with open("itoc.json") as itoc_file:
            char_to_index = json.load(ctoi_file)
            index_to_char = json.load(itoc_file)
            return char_to_index, index_to_char
21
+
22
# Module-level model state shared by the prediction functions below.
# Each loader is wrapped in st.cache_resource.
count_p = init_count_model()                 # table indexed by character id
single_layer_w = init_single_layer_model()   # weights applied to one-hot input
ctoi, itoc = init_char_index_mappings()      # char -> id and id (as str) -> char
25
+
26
def predict_with_count(starting_char: str, num_words):
    """Sample names from the count-based model.

    Starting from ``starting_char``, the next character id is repeatedly
    drawn from the row of ``count_p`` selected by the previous character id,
    until id 0 (the '.' end-of-word marker) is drawn.

    Args:
        starting_char: First character of every generated name. Matched
            case-insensitively against the keys of ``ctoi``.
        num_words: Number of names to generate.

    Returns:
        A list of ``num_words`` strings, each beginning with the lower-cased
        starting character. The terminating '.' is not included.

    Raises:
        KeyError: If ``starting_char`` has no entry in ``ctoi``.
    """
    # Accept upper-case input instead of failing on the lookup, and give a
    # readable message for characters the model does not know.
    first_char = starting_char.lower()
    if first_char not in ctoi:
        raise KeyError(f"unsupported starting character: {starting_char!r}")
    start_ix = ctoi[first_char]  # loop-invariant, so looked up once

    # A fresh generator seeded with SEED makes every call reproducible.
    g = torch.Generator().manual_seed(SEED)
    output = []
    for _ in range(num_words):
        prev = start_ix
        out = [first_char]
        while True:
            p = count_p[prev]
            pred = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
            if pred == 0:
                break  # '.' predicted -> end of word; the marker is not emitted
            out.append(itoc[str(pred)])  # itoc comes from JSON, so keys are strings
            prev = pred
        output.append(''.join(out))
    return output
42
+
43
def predict_with_single_layer_nn(starting_char: str, num_words):
    """Sample names from the single-linear-layer model.

    At each step the current character id is one-hot encoded, multiplied by
    ``single_layer_w`` to get logits, turned into probabilities by
    exponentiating and normalising, and the next id is drawn from that
    distribution. Generation of a name stops when id 0 (the '.' end-of-word
    marker) is drawn.

    Args:
        starting_char: First character of every generated name. Matched
            case-insensitively against the keys of ``ctoi``.
        num_words: Number of names to generate.

    Returns:
        A list of ``num_words`` strings, each beginning with the lower-cased
        starting character. The terminating '.' is not included.

    Raises:
        KeyError: If ``starting_char`` has no entry in ``ctoi``.
    """
    # Accept upper-case input instead of failing on the lookup, and give a
    # readable message for characters the model does not know.
    first_char = starting_char.lower()
    if first_char not in ctoi:
        raise KeyError(f"unsupported starting character: {starting_char!r}")
    start_ix = ctoi[first_char]  # loop-invariant, so looked up once

    # The one-hot width must equal the number of rows of the weight matrix
    # for the matmul below, so derive it instead of hard-coding 27.
    num_classes = single_layer_w.shape[0]

    # A fresh generator seeded with SEED makes every call reproducible.
    g = torch.Generator().manual_seed(SEED)
    output = []
    # Inference only: skip autograd bookkeeping in case the stored weights
    # were saved with requires_grad=True (not verifiable from this file).
    with torch.no_grad():
        for _ in range(num_words):
            ix = start_ix
            out = [first_char]
            while True:
                xenc = F.one_hot(torch.tensor([ix]), num_classes=num_classes).float()
                logits = xenc @ single_layer_w
                # exp + row-normalise == softmax over the logits
                counts = logits.exp()
                probs = counts / counts.sum(1, keepdim=True)

                ix = torch.multinomial(probs, generator=g, replacement=True, num_samples=1).item()
                if ix == 0:
                    break  # '.' predicted -> end of word; the marker is not emitted
                out.append(itoc[str(ix)])  # itoc comes from JSON, so keys are strings
            output.append(''.join(out))
    return output
62
+
63
def predict(query, num_words):
    """Generate names with both models and render them side by side.

    Args:
        query: Starting character passed to both prediction functions.
        num_words: Number of names requested from each model; also the
            number of rows in the displayed table.
    """
    count_names = predict_with_count(query, num_words)
    nn_names = predict_with_single_layer_nn(query, num_words)

    # One column per model, one row per generated name.
    table = pd.DataFrame(
        {
            "Count Based Language Model": count_names,
            "Single Linear Layer Language Model": nn_names,
        },
        index=range(num_words),
    )
    st.write(table)
68
+
69
# title and description
st.title("""
Make More Names.

This app creates the requested number of names starting with the input character below. The results will be predicted from the basic count based to advanced transformer based Character Level Language Model.""")

# search bar (limited to a single character)
query = st.text_input("Please input the starting character...", "", max_chars=1)

# number of words slider
num_words = st.slider("Number of names to generate:", min_value=1, max_value=50, value=5)

if query != "":
    # The vocabulary in ctoi is lower-case only, so match case-insensitively
    # and reject anything the models have no entry for (digits, spaces,
    # punctuation, ...) instead of letting the lookup raise KeyError.
    if query.lower() in ctoi:
        predict(query.lower(), num_words)
    else:
        st.warning(f"'{query}' is not a supported starting character. Please enter a letter from a to z.")
count_probs.pt ADDED
Binary file (3.83 kB). View file
 
ctoi.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6, "g": 7, "h": 8, "i": 9, "j": 10, "k": 11, "l": 12, "m": 13, "n": 14, "o": 15, "p": 16, "q": 17, "r": 18, "s": 19, "t": 20, "u": 21, "v": 22, "w": 23, "x": 24, "y": 25, "z": 26, ".": 0}
itoc.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"1": "a", "2": "b", "3": "c", "4": "d", "5": "e", "6": "f", "7": "g", "8": "h", "9": "i", "10": "j", "11": "k", "12": "l", "13": "m", "14": "n", "15": "o", "16": "p", "17": "q", "18": "r", "19": "s", "20": "t", "21": "u", "22": "v", "23": "w", "24": "x", "25": "y", "26": "z", "0": "."}
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ pandas
2
+ streamlit
3
+ torch
single_layer.pt ADDED
Binary file (3.84 kB). View file