wb-droid committed
Commit 6ae201e · 1 Parent(s): 6d5654f

new changes.
Files changed (3)
  1. app.py +85 -1
  2. myTextEmbedding.py +0 -1
  3. requirements.txt +0 -1
app.py CHANGED
@@ -1,6 +1,90 @@
-from myTextEmbedding import *
+#from myTextEmbedding import *
 import gradio as gr
 
+import torch
+import torch.nn as nn
+from torch import tensor
+from transformers import BertModel, BertTokenizer
+#import gzip
+#import pandas as pd
+import requests
+
+
+class EmbeddingModel(nn.Module):
+    def __init__(self, bertName = "bert-base-uncased"): # other bert models can also be supported
+        super().__init__()
+        self.bertName = bertName
+        # use BERT model
+        self.tokenizer = BertTokenizer.from_pretrained(self.bertName)
+        self.model = BertModel.from_pretrained(self.bertName)
+
+    def forward(self, s, device = "cuda"):
+        # get tokens, which also include attention_mask
+        tokens = self.tokenizer(s, return_tensors='pt', padding = "max_length", truncation = True, max_length = 256).to(device)
+
+        # get token embeddings
+        output = self.model(**tokens)
+        tokens_embeddings = output.last_hidden_state
+        #print("tokens_embeddings:" + str(tokens_embeddings.shape))
+
+        # mean pooling to get text embedding
+        embeddings = tokens_embeddings * tokens.attention_mask[...,None] # [B, T, emb]
+        #print("embeddings:" + str(embeddings.shape))
+
+        embeddings = embeddings.sum(1) # [B, emb]
+        valid_tokens = tokens.attention_mask.sum(1) # [B]
+        embeddings = embeddings / valid_tokens[...,None] # [B, emb]
+
+        return embeddings
+
+    # from scratch: nn.CosineSimilarity(dim = 1)(q,a)
+    def cos_score(self, q, a):
+        q_norm = q / (q.pow(2).sum(dim=1, keepdim=True).pow(0.5))
+        r_norm = a / (a.pow(2).sum(dim=1, keepdim=True).pow(0.5))
+        return (q_norm @ r_norm.T).diagonal()
+
+# contrastive training
+class TrainModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.m = EmbeddingModel("bert-base-uncased")
+
+    def forward(self, s1, s2, score):
+        cos_score = self.m.cos_score(self.m(s1), self.m(s2))
+        loss = nn.MSELoss()(cos_score, score)
+        return loss, cos_score
+
+def searchWiki(s):
+    response = requests.get(
+        'https://en.wikipedia.org/w/api.php',
+        params={
+            'action': 'query',
+            'format': 'json',
+            'titles': s,
+            'prop': 'extracts',
+            'exintro': True,
+            'explaintext': True,
+        }
+    ).json()
+    page = next(iter(response['query']['pages'].values()))
+    return page['extract'].replace("\n","")
+
+# sentence chunking
+def chunk(w):
+    return w.split(".")
+
+def generate_chunk_data(concepts):
+    wiki_data = [searchWiki(c).replace("\n","") for c in concepts]
+    chunk_data = []
+    for w in wiki_data:
+        chunk_data = chunk_data + chunk(w)
+
+    chunk_data = [c.strip()+"." for c in chunk_data]
+    while '.' in chunk_data:
+        chunk_data.remove('.')
+
+    return chunk_data
+
 def generate_chunk_emb(m, chunk_data):
     with torch.no_grad():
         emb = m(chunk_data, device = "cpu")
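
For context (not part of the commit), a minimal sketch of how the pieces inlined into app.py above might be wired together. The concept title, the query string, and the assumption that generate_chunk_emb returns the chunk-embedding tensor are illustrative; the diff only shows the first lines of that function.

# Minimal sketch, not part of this commit. Assumes generate_chunk_emb(m, chunk_data)
# returns the chunk embeddings computed above as an [N, emb] tensor, and that CPU
# inference is acceptable.
m = EmbeddingModel("bert-base-uncased")
m.eval()

concepts = ["Text embedding"]                  # hypothetical Wikipedia title
chunk_data = generate_chunk_data(concepts)     # fetch article intros and split into sentences
chunk_emb = generate_chunk_emb(m, chunk_data)  # assumed to return the [N, emb] embeddings

with torch.no_grad():
    query_emb = m(["What is a text embedding?"], device="cpu")  # [1, emb]

# rank chunks by cosine similarity to the query ([1, emb] broadcasts against [N, emb])
scores = nn.CosineSimilarity(dim=1)(query_emb, chunk_emb)
best_chunk = chunk_data[scores.argmax().item()]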
myTextEmbedding.py CHANGED
@@ -2,7 +2,6 @@ import torch
 import torch.nn as nn
 from torch import tensor
 from transformers import BertModel, BertTokenizer
-#import gzip
 import pandas as pd
 import requests
 
requirements.txt CHANGED
@@ -1,5 +1,4 @@
 torch
 transformers
-pandas
 requests
 gradio