atrytone committed on
Commit
0f56fb9
1 Parent(s): c3141d7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -29
app.py CHANGED
@@ -4,31 +4,26 @@ from langchain.embeddings import HuggingFaceEmbeddings
4
  import torch
5
 
6
 
7
- def create_miread_embed(sents, bundle):
8
- tokenizer = bundle[0]
9
- model = bundle[1]
10
- model.cpu()
11
- tokens = tokenizer(sents,
12
- max_length=512,
13
- padding=True,
14
- truncation=True,
15
- return_tensors="pt"
16
- )
17
- device = torch.device('cpu')
18
- tokens = tokens.to(device)
19
- with torch.no_grad():
20
- out = model.bert(**tokens)
21
- feature = out.last_hidden_state[:, 0, :]
22
- return feature.cpu()
23
-
24
-
25
- def get_matches(query):
26
- matches = vecdb.similarity_search_with_score(query, k=60)
27
  return matches
28
 
 
 
 
 
 
 
 
29
 
30
- def inference(query):
31
- matches = get_matches(query)
 
 
 
 
 
 
32
  auth_counts = {}
33
  j_bucket = {}
34
  n_table = []
@@ -94,17 +89,40 @@ def inference(query):
94
 
95
  return [a_output, j_output, n_output]
96
 
 
 
 
 
 
 
 
 
 
97
 
98
- model_name = "biodatlab/MIReAD-Neuro-Contrastive"
 
 
99
  model_kwargs = {'device': 'cpu'}
100
  encode_kwargs = {'normalize_embeddings': False}
101
- faiss_embedder = HuggingFaceEmbeddings(
102
- model_name=model_name,
 
 
 
 
 
 
 
 
 
 
103
  model_kwargs=model_kwargs,
104
  encode_kwargs=encode_kwargs
105
  )
106
 
107
- vecdb = FAISS.load_local("nbdt_contr", faiss_embedder)
 
 
108
 
109
 
110
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
@@ -114,11 +132,12 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
114
  To find a recommendation, paste a `title[SEP]abstract` or `abstract` in the text box below and click \"Find Matches\".\
115
  Then, you can hover to authors/abstracts/journals tab to find a suggested list.\
116
  The data in our current demo includes authors associated with the NBDT Journal. We will update the data monthly for an up-to-date publications.")
117
- gr.Markdown("**Model on Deployment: " + model_name + "**")
118
 
119
  abst = gr.Textbox(label="Abstract", lines=10)
120
 
121
- action_btn = gr.Button(value="Find Matches")
 
 
122
 
123
  with gr.Tab("Authors"):
124
  n_output = gr.Dataframe(
@@ -147,7 +166,19 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
147
  visible=False
148
  )
149
 
150
- action_btn.click(fn=inference,
 
 
 
 
 
 
 
 
 
 
 
 
151
  inputs=[
152
  abst,
153
  ],
 
4
  import torch
5
 
6
 
7
+ def get_matches1(query):
8
+ matches = vecdb1.similarity_search_with_score(query, k=60)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  return matches
10
 
11
+ def get_matches2(query):
12
+ matches = vecdb2.similarity_search_with_score(query, k=60)
13
+ return matches
14
+
15
+ def get_matches3(query):
16
+ matches = vecdb3.similarity_search_with_score(query, k=60)
17
+ return matches
18
 
19
+
20
+ def inference(query,model=1):
21
+ if model==1:
22
+ matches = get_matches1(query)
23
+ elif model==2:
24
+ matches = get_matches2(query)
25
+ else:
26
+ matches = get_matches3(query)
27
  auth_counts = {}
28
  j_bucket = {}
29
  n_table = []
 
89
 
90
  return [a_output, j_output, n_output]
91
 
92
+ def inference1(query):
93
+ return inference(query,1)
94
+
95
+ def inference2(query):
96
+ return inference(query,2)
97
+
98
+ def inference3(query):
99
+ return inference(query,3)
100
+
101
 
102
+ model1_name = "biodatlab/MIReAD-Neuro-Large"
103
+ model2_name = "biodatlab/MIReAD-Neuro-Contrastive"
104
+ model3_name = "biodatlab/SciBERT-Neuro-Contrastive"
105
  model_kwargs = {'device': 'cpu'}
106
  encode_kwargs = {'normalize_embeddings': False}
107
+ faiss_embedder1 = HuggingFaceEmbeddings(
108
+ model_name=model1_name,
109
+ model_kwargs=model_kwargs,
110
+ encode_kwargs=encode_kwargs
111
+ )
112
+ faiss_embedder2 = HuggingFaceEmbeddings(
113
+ model_name=model2_name,
114
+ model_kwargs=model_kwargs,
115
+ encode_kwargs=encode_kwargs
116
+ )
117
+ faiss_embedder3 = HuggingFaceEmbeddings(
118
+ model_name=model3_name,
119
  model_kwargs=model_kwargs,
120
  encode_kwargs=encode_kwargs
121
  )
122
 
123
+ vecdb1 = FAISS.load_local("miread_large", faiss_embedder1)
124
+ vecdb2 = FAISS.load_local("miread_contrastive", faiss_embedder2)
125
+ vecdb3 = FAISS.load_local("scibert_contrastive", faiss_embedder3)
126
 
127
 
128
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
132
  To find a recommendation, paste a `title[SEP]abstract` or `abstract` in the text box below and click \"Find Matches\".\
133
  Then, you can hover to authors/abstracts/journals tab to find a suggested list.\
134
  The data in our current demo includes authors associated with the NBDT Journal. We will update the data monthly for an up-to-date publications.")
 
135
 
136
  abst = gr.Textbox(label="Abstract", lines=10)
137
 
138
+ action_btn1 = gr.Button(value="Find Matches with MIReAD-Neuro-Large")
139
+ action_btn2 = gr.Button(value="Find Matches with MIReAD-Neuro-Contrastive")
140
+ action_btn3 = gr.Button(value="Find Matches with SciBERT-Neuro-Contrastive")
141
 
142
  with gr.Tab("Authors"):
143
  n_output = gr.Dataframe(
 
166
  visible=False
167
  )
168
 
169
+ action_btn1.click(fn=inference1,
170
+ inputs=[
171
+ abst,
172
+ ],
173
+ outputs=[a_output, j_output, n_output],
174
+ api_name="neurojane")
175
+ action_btn2.click(fn=inference2,
176
+ inputs=[
177
+ abst,
178
+ ],
179
+ outputs=[a_output, j_output, n_output],
180
+ api_name="neurojane")
181
+ action_btn3.click(fn=inference3,
182
  inputs=[
183
  abst,
184
  ],