saichandrapandraju committed
Commit
e81673f
1 Parent(s): cc7c5b8

implement Embedding Based Single Layer LM

Files changed (2)
  1. app.py +33 -2
  2. mlp.pt +0 -0
app.py CHANGED
@@ -14,6 +14,11 @@ def init_count_model():
 def init_single_layer_model():
     return torch.load("single_layer.pt")
 
+@st.cache_resource
+def init_mlp():
+    mlp_layers = torch.load("mlp.pt")
+    return mlp_layers["emb"], mlp_layers['w1'], mlp_layers['b1'], mlp_layers['w2'], mlp_layers['b2']
+
 @st.cache_resource
 def init_char_index_mappings():
     with open("ctoi.json") as ci, open("itoc.json") as ic:
@@ -21,6 +26,7 @@ def init_char_index_mappings():
 
 count_p = init_count_model()
 single_layer_w = init_single_layer_model()
+mlp_emb, mlp_w1, mlp_b1, mlp_w2, mlp_b2 = init_mlp()
 ctoi, itoc = init_char_index_mappings()
 
 def predict_with_count(starting_char:str, num_words):
@@ -64,10 +70,35 @@ def predict_with_single_layer_nn(starting_char:str, num_words):
         output.append(''.join(out[:-1]))
     return output
 
+def predict_with_mlp(starting_char:str, num_words):
+    g = torch.Generator().manual_seed(SEED)
+    output = []
+    context_length = 3
+    for _ in range(num_words):
+        out = []
+        context = [0]*(context_length-1)
+        if starting_char not in ctoi:
+            raise ValueError("Starting Character is not a valid alphabet. Please input a valid alphabet.")
+        ix = ctoi[starting_char]
+        out.append(starting_char)
+        context += [ix]
+        while True:
+            emb = mlp_emb[torch.tensor([context])]
+            h = torch.tanh(emb.view(1,-1) @ mlp_w1 + mlp_b1) # create batch_size 1
+            logits = h @ mlp_w2 + mlp_b2
+            probs = F.softmax(logits, dim=1)
+            ix = torch.multinomial(probs, num_samples=1, generator=g).item()
+            context = context[1:] + [ix]
+            out.append(itoc[str(ix)])
+            if ix == 0:
+                break
+        output.append(''.join(out[:-1]))
+    return output
+
 def predict(query, num_words):
     try:
-        preds = [predict_with_count(query, num_words), predict_with_single_layer_nn(query, num_words)]
-        labels = ["Count Based Language Model", "Single Linear Layer Language Model"]
+        preds = [predict_with_count(query, num_words), predict_with_single_layer_nn(query, num_words), predict_with_mlp(query, num_words)]
+        labels = ["Count Based LM", "Single Linear Layer LM", "Embedding Based Single Hidden Layer LM"]
         results = {labels[idx]: preds[idx] for idx in range(len(preds))}
         st.write(pd.DataFrame(results, index=range(num_words)))
     except ValueError as e:
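
The new predict_with_mlp does an embedding lookup over the context window, a tanh hidden layer, a linear output layer, and multinomial sampling from the softmax. Below is a minimal standalone sketch of one sampling step under assumed dimensions (27-character vocabulary, embedding size 10, hidden size 100, context length 3); the randomly initialized tensors and the seed are illustrative stand-ins for the parameters loaded from mlp.pt, not values confirmed by this commit.

# Sketch of one sampling step of the embedding-based MLP.
# All shapes and the seed are illustrative assumptions, not values from mlp.pt.
import torch
import torch.nn.functional as F

g = torch.Generator().manual_seed(42)
vocab, emb_dim, hidden, ctx = 27, 10, 100, 3

mlp_emb = torch.randn((vocab, emb_dim), generator=g)
mlp_w1 = torch.randn((ctx * emb_dim, hidden), generator=g)
mlp_b1 = torch.randn(hidden, generator=g)
mlp_w2 = torch.randn((hidden, vocab), generator=g)
mlp_b2 = torch.randn(vocab, generator=g)

context = [0, 0, 1]                                # two padding indices + the starting character's index
emb = mlp_emb[torch.tensor([context])]             # (1, 3, 10): embedding lookup for the context window
h = torch.tanh(emb.view(1, -1) @ mlp_w1 + mlp_b1)  # (1, 100): flatten the context into a batch of one
logits = h @ mlp_w2 + mlp_b2                       # (1, 27): scores over the vocabulary
probs = F.softmax(logits, dim=1)
ix = torch.multinomial(probs, num_samples=1, generator=g).item()  # index of the sampled next character
print(ix)

Each sampled index is appended to the context window (dropping its oldest entry), and generation stops when index 0 is drawn, mirroring the while loop in predict_with_mlp above.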
mlp.pt ADDED
Binary file (49.3 kB).
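
init_mlp above expects mlp.pt to deserialize to a dict with the keys "emb", "w1", "b1", "w2", and "b2". A sketch of producing a checkpoint in that format, using the same assumed (untrained) shapes as the sketch above; the actual tensor shapes inside the committed file are not shown by this diff.

# Sketch of saving a checkpoint in the format init_mlp expects.
# Untrained placeholder tensors; the real shapes in mlp.pt are assumptions here.
import torch

emb = torch.randn(27, 10)    # character embedding table (vocab x emb_dim)
w1 = torch.randn(30, 100)    # hidden layer weights (context_length * emb_dim -> hidden)
b1 = torch.randn(100)
w2 = torch.randn(100, 27)    # output layer weights (hidden -> vocab)
b2 = torch.randn(27)

torch.save({"emb": emb, "w1": w1, "b1": b1, "w2": w2, "b2": b2}, "mlp.pt")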