asahi417 commited on
Commit
87187b9
1 Parent(s): cd474b3
Files changed (3) hide show
  1. app.py +83 -4
  2. flagged/log.csv +3 -0
  3. requirements.txt +1 -0
app.py CHANGED
@@ -1,7 +1,86 @@
 
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from relbert import RelBERT
2
  import gradio as gr
3
 
4
+ model = RelBERT()
 
5
 
6
+
7
+ def cosine_similarity(a, b, zero_vector_mask: float = -100):
8
+ norm_a = sum(map(lambda x: x * x, a)) ** 0.5
9
+ norm_b = sum(map(lambda x: x * x, b)) ** 0.5
10
+ if norm_b * norm_a == 0:
11
+ return zero_vector_mask
12
+ return sum(map(lambda x: x[0] * x[1], zip(a, b)))/(norm_a * norm_b)
13
+
14
+
15
+ def greet(
16
+ query,
17
+ candidate_1,
18
+ candidate_2,
19
+ candidate_3,
20
+ candidate_4,
21
+ candidate_5,
22
+ candidate_6,
23
+ candidate_7,
24
+ candidate_8,
25
+ candidate_9,
26
+ candidate_10):
27
+ pairs = []
28
+ pairs_id = []
29
+ for n, i in enumerate([
30
+ candidate_1,
31
+ candidate_2,
32
+ candidate_3,
33
+ candidate_4,
34
+ candidate_5,
35
+ candidate_6,
36
+ candidate_7,
37
+ candidate_8,
38
+ candidate_9,
39
+ candidate_10
40
+ ]):
41
+ query = query.split(',')
42
+ # validate query
43
+ if len(query) == 0:
44
+ raise ValueError(f'ERROR: query is empty {query}')
45
+ if len(query) == 1:
46
+ raise ValueError(f'ERROR: query contains single word {query}')
47
+ if len(query) > 2:
48
+ raise ValueError(f'ERROR: query contains more than two word {query}')
49
+
50
+ if i != '':
51
+ if len(i.split(',')) != 1:
52
+ raise ValueError(f'ERROR: candidate {n + 1} contains single word {i.split(",")}')
53
+ if len(i.split(',')) > 2:
54
+ raise ValueError(f'ERROR: candidate {n + 1} contains more than two word {i.split(",")}')
55
+ pairs.append(i.split(','))
56
+ pairs_id.append(n+1)
57
+ if len(pairs_id) < 2:
58
+ raise ValueError(f'ERROR: please specify at least two candidates: {pairs}')
59
+ vectors = model.get_embedding(pairs+[query])
60
+ vector_q = vectors.pop(-1)
61
+ sims = []
62
+ for v in vectors:
63
+ sims.append(cosine_similarity(v, vector_q))
64
+ output = sorted(list(zip(pairs_id, sims, pairs)), key=lambda _x: _x[1], reverse=True)
65
+ output = {f'candidate {n}: [{p[0]}, {p[1]}]': s for n, (i, s, p) in enumerate(output)}
66
+ return output
67
+
68
+
69
+ demo = gr.Interface(
70
+ fn=greet,
71
+ inputs=[
72
+ gr.Textbox(lines=1, placeholder="Query Word Pair (separate by comma i.e. 'scotch whisky,wheat')"),
73
+ gr.Textbox(lines=1, placeholder="Candidate Word Pair 1 (separate by comma i.e. 'scotch whisky,wheat')"),
74
+ gr.Textbox(lines=1, placeholder="Candidate Word Pair 2 (separate by comma i.e. 'scotch whisky,wheat')"),
75
+ gr.Textbox(lines=1, placeholder="Candidate Word Pair 3 (separate by comma i.e. 'scotch whisky,wheat')"),
76
+ gr.Textbox(lines=1, placeholder="Candidate Word Pair 4 (separate by comma i.e. 'scotch whisky,wheat')"),
77
+ gr.Textbox(lines=1, placeholder="Candidate Word Pair 5 (separate by comma i.e. 'scotch whisky,wheat')"),
78
+ gr.Textbox(lines=1, placeholder="Candidate Word Pair 6 (separate by comma i.e. 'scotch whisky,wheat')"),
79
+ gr.Textbox(lines=1, placeholder="Candidate Word Pair 7 (separate by comma i.e. 'scotch whisky,wheat')"),
80
+ gr.Textbox(lines=1, placeholder="Candidate Word Pair 8 (separate by comma i.e. 'scotch whisky,wheat')"),
81
+ gr.Textbox(lines=1, placeholder="Candidate Word Pair 9 (separate by comma i.e. 'scotch whisky,wheat')"),
82
+ gr.Textbox(lines=1, placeholder="Candidate Word Pair 10 (separate by comma i.e. 'scotch whisky,wheat')")
83
+ ],
84
+ outputs="label",
85
+ )
86
+ demo.launch(show_error=True)
flagged/log.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ 'query','candidate_1','candidate_2','candidate_3','candidate_4','candidate_5','candidate_6','candidate_7','candidate_8','candidate_9','candidate_10','output','flag','username','timestamp'
2
+ '','','','','','','','','','','','','','','2022-08-12 18:23:10.741127'
3
+ '','{"dog": 0.7, "cat": 0.3}','','','2022-08-12 18:59:42.932316'
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ relbert