asahi417 commited on
Commit
f58d9a0
1 Parent(s): 1a6de80
Files changed (1) hide show
  1. app.py +12 -4
app.py CHANGED
@@ -1,9 +1,18 @@
 
 
1
  from relbert import RelBERT
2
  import gradio as gr
3
 
4
  model = RelBERT(model='relbert/relbert-roberta-large')
5
 
6
 
 
 
 
 
 
 
 
7
  def cosine_similarity(a, b, zero_vector_mask: float = -100):
8
  norm_a = sum(map(lambda x: x * x, a)) ** 0.5
9
  norm_b = sum(map(lambda x: x * x, b)) ** 0.5
@@ -68,6 +77,8 @@ def greet(
68
  return output
69
 
70
 
 
 
71
  demo = gr.Interface(
72
  fn=greet,
73
  inputs=[
@@ -84,10 +95,7 @@ demo = gr.Interface(
84
  gr.Textbox(lines=1, placeholder="Candidate Word Pair 10 (separate by comma)")
85
  ],
86
  outputs="label",
87
- examples=[
88
- ["beauty,aesthete", "pleasure,hedonist", "emotion,demagogue", "opinion,sympathizer", "seance,medium", "luxury,ascetic"] + [''] * 5,
89
- ["classroom,desk", "bank,dollar", "church,pew", "studio,paintbrush", "museum,artifact"] + [''] * 6,
90
- ],
91
  )
92
  demo.launch(show_error=True)
93
 
 
1
+ import json
2
+ import requests
3
  from relbert import RelBERT
4
  import gradio as gr
5
 
6
  model = RelBERT(model='relbert/relbert-roberta-large')
7
 
8
 
9
+ def get_example():
10
+ url = "https://huggingface.co/datasets/relbert/analogy_questions/raw/main/dataset/sat/test.jsonl"
11
+ r = requests.get(url)
12
+ example = [json.loads(i) for i in r.content.decode().split('\n') if len(i) > 0]
13
+ return example
14
+
15
+
16
  def cosine_similarity(a, b, zero_vector_mask: float = -100):
17
  norm_a = sum(map(lambda x: x * x, a)) ** 0.5
18
  norm_b = sum(map(lambda x: x * x, b)) ** 0.5
 
77
  return output
78
 
79
 
80
+ examples = get_example()[:15]
81
+ examples = [[','.join(i['stem'])] + [','.join(c) for c in i['choice'] + [''] * (10 - len(i['choice']))] for i in examples]
82
  demo = gr.Interface(
83
  fn=greet,
84
  inputs=[
 
95
  gr.Textbox(lines=1, placeholder="Candidate Word Pair 10 (separate by comma)")
96
  ],
97
  outputs="label",
98
+ examples=examples
 
 
 
99
  )
100
  demo.launch(show_error=True)
101