Spaces:
Runtime error
Runtime error
saichandrapandraju
commited on
Commit
•
3e4ffd8
1
Parent(s):
275ba7e
add count based and single layer nn
Browse files- app.py +82 -0
- count_probs.pt +0 -0
- ctoi.json +1 -0
- itoc.json +1 -0
- requirements.txt +3 -0
- single_layer.pt +0 -0
app.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import torch
|
3 |
+
import json
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import pandas as pd
|
6 |
+
|
7 |
+
SEED = 42
|
8 |
+
|
9 |
+
@st.cache_resource
|
10 |
+
def init_count_model():
|
11 |
+
return torch.load("count_probs.pt")
|
12 |
+
|
13 |
+
@st.cache_resource
|
14 |
+
def init_single_layer_model():
|
15 |
+
return torch.load("single_layer.pt")
|
16 |
+
|
17 |
+
@st.cache_resource
|
18 |
+
def init_char_index_mappings():
|
19 |
+
with open("ctoi.json") as ci, open("itoc.json") as ic:
|
20 |
+
return json.load(ci), json.load(ic)
|
21 |
+
|
22 |
+
count_p = init_count_model()
|
23 |
+
single_layer_w = init_single_layer_model()
|
24 |
+
ctoi, itoc = init_char_index_mappings()
|
25 |
+
|
26 |
+
def predict_with_count(starting_char:str, num_words):
|
27 |
+
g = torch.Generator().manual_seed(SEED)
|
28 |
+
output = []
|
29 |
+
for _ in range(num_words):
|
30 |
+
prev = ctoi[starting_char]
|
31 |
+
out = []
|
32 |
+
out.append(starting_char)
|
33 |
+
while True:
|
34 |
+
p = count_p[prev]
|
35 |
+
pred = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
|
36 |
+
out.append(itoc[str(pred)])
|
37 |
+
if pred==0:
|
38 |
+
break # end if '.' is predicted -> end of word
|
39 |
+
prev = pred
|
40 |
+
output.append(''.join(out[:-1])) # discard '.' at the end
|
41 |
+
return output
|
42 |
+
|
43 |
+
def predict_with_single_layer_nn(starting_char:str, num_words):
|
44 |
+
g = torch.Generator().manual_seed(SEED)
|
45 |
+
output = []
|
46 |
+
for _ in range(num_words):
|
47 |
+
out = []
|
48 |
+
ix = ctoi[starting_char]
|
49 |
+
out.append(starting_char)
|
50 |
+
while True:
|
51 |
+
xenc = F.one_hot(torch.tensor([ix]), num_classes=27).float()
|
52 |
+
logits = xenc @ single_layer_w
|
53 |
+
counts = logits.exp()
|
54 |
+
probs = counts/counts.sum(1, keepdim=True)
|
55 |
+
|
56 |
+
ix = torch.multinomial(probs, generator=g, replacement=True, num_samples=1).item()
|
57 |
+
out.append(itoc[str(ix)])
|
58 |
+
if ix==0:
|
59 |
+
break
|
60 |
+
output.append(''.join(out[:-1]))
|
61 |
+
return output
|
62 |
+
|
63 |
+
def predict(query, num_words):
|
64 |
+
preds = [predict_with_count(query, num_words), predict_with_single_layer_nn(query, num_words)]
|
65 |
+
labels = ["Count Based Language Model", "Single Linear Layer Language Model"]
|
66 |
+
results = {labels[idx]: preds[idx] for idx in range(len(preds))}
|
67 |
+
st.write(pd.DataFrame(results, index=range(num_words)))
|
68 |
+
|
69 |
+
# title and description
|
70 |
+
st.title("""
|
71 |
+
Make More Names.
|
72 |
+
|
73 |
+
This app creates the requested number of names starting with the input character below. The results will be predicted from the basic count based to advanced transformer based Character Level Language Model.""")
|
74 |
+
|
75 |
+
# search bar
|
76 |
+
query = st.text_input("Please input the starting character...", "", max_chars=1)
|
77 |
+
|
78 |
+
# number of words slider
|
79 |
+
num_words = st.slider("Number of names to generate:", min_value=1, max_value=50, value=5)
|
80 |
+
|
81 |
+
if query != "":
|
82 |
+
predict(query, num_words)
|
count_probs.pt
ADDED
Binary file (3.83 kB). View file
|
|
ctoi.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6, "g": 7, "h": 8, "i": 9, "j": 10, "k": 11, "l": 12, "m": 13, "n": 14, "o": 15, "p": 16, "q": 17, "r": 18, "s": 19, "t": 20, "u": 21, "v": 22, "w": 23, "x": 24, "y": 25, "z": 26, ".": 0}
|
itoc.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"1": "a", "2": "b", "3": "c", "4": "d", "5": "e", "6": "f", "7": "g", "8": "h", "9": "i", "10": "j", "11": "k", "12": "l", "13": "m", "14": "n", "15": "o", "16": "p", "17": "q", "18": "r", "19": "s", "20": "t", "21": "u", "22": "v", "23": "w", "24": "x", "25": "y", "26": "z", "0": "."}
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
pandas
|
2 |
+
streamlit
|
3 |
+
torch
|
single_layer.pt
ADDED
Binary file (3.84 kB). View file
|
|