orionweller committed on
Commit
7eba807
1 Parent(s): 61f7243
Files changed (2) hide show
  1. app.py +26 -35
  2. requirements.txt +3 -0
app.py CHANGED
@@ -1,56 +1,40 @@
1
  import gradio as gr
2
  from sentence_transformers import SentenceTransformer, util, CrossEncoder
3
- import torch
4
  from transformers import set_seed
5
  import numpy as np
6
- import pandas as pd
7
- import argparse
8
 
9
  set_seed(42)
10
 
 
11
 
12
- def calc_preferred_dense(doc1, doc2, q1, q2, model_name="dpr", model=None):
 
 
 
 
 
 
 
 
 
 
13
  """
14
  Input:
15
  doc1, doc2: strings containing the documents/passages
16
  q1, q2: strings for queries that are only relevant to the corresponding doc (doc1 -> q1, doc2 -> q2)
17
  model_name: string containing the type of model to run
18
- model: the preloaded model, if caching
19
 
20
  Returns:
21
  A dictionary containing each query (q1 or q2) and the score (P@1) for the pair
22
 
23
  """
24
- ### Model initialization
25
- if model_name == "dpr":
26
- model_type = "dpr"
27
- if model is not None:
28
- passage_encoder, query_encoder = model
29
- else:
30
- passage_encoder = SentenceTransformer(
31
- "facebook-dpr-ctx_encoder-multiset-base"
32
- )
33
- query_encoder = SentenceTransformer(
34
- "facebook-dpr-question_encoder-multiset-base"
35
- )
36
- elif "cross-encoder" in model_name or "t5" in model_name:
37
- model_type = "cross_encoder"
38
- if model is None:
39
- model = CrossEncoder(model_name)
40
- else:
41
- model_type = "biencoder"
42
- if model is not None:
43
- embedder = model
44
- else:
45
- embedder = SentenceTransformer(model_name)
46
-
47
  corpus = [doc1, doc2]
48
  queries = [q1, q2]
49
  results = {}
50
  num_correct = 0
51
 
52
  ### Do Retrieval
53
- if model_type == "dpr":
54
  passage_embeddings = passage_encoder.encode(corpus)
55
 
56
  query_encoder = SentenceTransformer(
@@ -69,7 +53,7 @@ def calc_preferred_dense(doc1, doc2, q1, q2, model_name="dpr", model=None):
69
  num_correct += 1
70
  model = (passage_encoder, query_encoder)
71
 
72
- elif model_type == "cross_encoder":
73
  for idx, query in enumerate(queries):
74
  scores = model.predict([[query, doc1], [query, doc2]])
75
  results[f"q{idx+1}"] = scores.tolist()
@@ -100,11 +84,18 @@ def calc_preferred_dense(doc1, doc2, q1, q2, model_name="dpr", model=None):
100
  model = embedder
101
 
102
  results["score"] = num_correct / 2
103
- return results, model
104
 
105
 
106
- def greet(name):
107
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
108
 
109
- iface = gr.Interface(fn=calc_preferred_dense, inputs="text", outputs="text")
110
- iface.launch()
 
1
  import gradio as gr
2
  from sentence_transformers import SentenceTransformer, util, CrossEncoder
 
3
  from transformers import set_seed
4
  import numpy as np
 
 
5
 
6
  set_seed(42)
7
 
8
+
9
 
10
+ passage_encoder = SentenceTransformer(
11
+ "facebook-dpr-ctx_encoder-multiset-base"
12
+ )
13
+ query_encoder = SentenceTransformer(
14
+ "facebook-dpr-question_encoder-multiset-base"
15
+ )
16
+ model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")
17
+ embedder = SentenceTransformer("all-mpnet-base-v2")
18
+
19
+
20
+ def calc_preferred_dense(doc1, doc2, q1, q2, model_name="dpr"):
21
  """
22
  Input:
23
  doc1, doc2: strings containing the documents/passages
24
  q1, q2: strings for queries that are only relevant to the corresponding doc (doc1 -> q1, doc2 -> q2)
25
  model_name: string containing the type of model to run
 
26
 
27
  Returns:
28
  A dictionary containing each query (q1 or q2) and the score (P@1) for the pair
29
 
30
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  corpus = [doc1, doc2]
32
  queries = [q1, q2]
33
  results = {}
34
  num_correct = 0
35
 
36
  ### Do Retrieval
37
+ if model_name == "dpr":
38
  passage_embeddings = passage_encoder.encode(corpus)
39
 
40
  query_encoder = SentenceTransformer(
 
53
  num_correct += 1
54
  model = (passage_encoder, query_encoder)
55
 
56
+ elif model_name == "cross_encoder":
57
  for idx, query in enumerate(queries):
58
  scores = model.predict([[query, doc1], [query, doc2]])
59
  results[f"q{idx+1}"] = scores.tolist()
 
84
  model = embedder
85
 
86
  results["score"] = num_correct / 2
87
+ return results
88
 
89
 
90
+ gr.Interface(
91
+ calc_preferred_dense,
92
+ [ gr.Textbox(label="Sentence 1"), gr.Textbox(label="Sentence 2"), gr.Dropdown(["dpr", "cross-encoder", "dense"], value="cross-encoder")],
93
+ [ gr.components.Label(label="Similarity score") ],
94
+ title="Similarity score between 2 sentences",
95
+ description="In this demo do provide 2 sentences bellow. They can even be in distinct languages. Powered by S-BERT multilingual model : https://www.sbert.net.",
96
+ examples=[['The sentences are mapped such that sentences with similar meanings are close in vector space.', 'Les phrases sont mappées de manière à ce que les phrases ayant des significations similaires soient proches dans l\'espace vectoriel.'],
97
+ ['You do not need to specify the input language.', 'You can use any language.']],
98
+ live=True,
99
+ allow_flagging="never"
100
+ ).launch(debug=True, enable_queue=True)
101
 
 
 
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio
2
+ sentence_transformers
3
+ numpy