File size: 3,179 Bytes
f58d9a0
 
c31d68b
87187b9
cd474b3
 
dba7731
cd474b3
87187b9
f58d9a0
 
 
 
 
 
 
87187b9
 
 
 
 
 
 
 
c31d68b
 
 
 
 
 
87187b9
 
 
 
 
 
 
e2905f0
c31d68b
7c630e0
 
 
 
 
 
 
 
87187b9
 
 
 
 
 
 
 
e2905f0
87187b9
7c630e0
 
c31d68b
7c630e0
 
 
 
 
 
 
 
 
 
 
 
 
217ce82
7c630e0
87187b9
 
f58d9a0
e2905f0
87187b9
 
 
ac00d10
 
 
 
 
 
 
87187b9
 
f58d9a0
87187b9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import json
import requests
import re
from relbert import RelBERT
import gradio as gr

# Instantiated once at module import so every call to `greet` reuses the same
# RelBERT instance (its `get_embedding` method embeds the word pairs).
model = RelBERT(model='relbert/relbert-roberta-large')


def get_example():
    """Download the SAT analogy test split and return its parsed records.

    The file is fetched over HTTP as JSON Lines; every non-blank line is
    parsed with ``json.loads``.

    Returns:
        list: One parsed JSON value per non-blank line, in file order.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within the timeout.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    url = "https://huggingface.co/datasets/relbert/analogy_questions/raw/main/dataset/sat/test.jsonl"
    # Without a timeout a stalled connection would block start-up forever.
    response = requests.get(url, timeout=30)
    # Fail here on an error status instead of feeding an error page to json.loads.
    response.raise_for_status()
    # Split on '\n' only: JSON strings may legally contain other Unicode line
    # separators, which str.splitlines() would break on.
    lines = response.content.decode().split('\n')
    return [json.loads(line) for line in lines if line.strip()]


def cosine_similarity(a, b, zero_vector_mask: float = -100):
    """Return the cosine similarity between the numeric sequences *a* and *b*.

    Args:
        a: First vector (a re-iterable sequence of numbers).
        b: Second vector (a re-iterable sequence of numbers).
        zero_vector_mask: Value returned when the product of the two
            Euclidean norms is zero, i.e. the similarity is undefined.

    Returns:
        The dot product of *a* and *b* divided by the product of their
        norms, or ``zero_vector_mask`` when that product is zero.
    """
    length_a = sum(x * x for x in a) ** 0.5
    length_b = sum(y * y for y in b) ** 0.5
    denominator = length_a * length_b
    if denominator == 0:
        return zero_vector_mask
    dot_product = sum(x * y for x, y in zip(a, b))
    return dot_product / denominator


def clean(text):
    """Return *text* with leading and trailing whitespace removed.

    Interior whitespace is left untouched.
    """
    # One pass: drop a whitespace run anchored at the very start or very end.
    return re.sub(r"\A\s+|\s+\Z", "", text)


def _parse_pair(text, label):
    """Split comma-separated *text* into a cleaned two-word pair.

    Args:
        text: Raw user input such as ``"word1, word2"``.
        label: Name used in error messages (e.g. ``"query"``).

    Returns:
        list: The two words with surrounding whitespace removed.

    Raises:
        ValueError: If *text* does not hold exactly two non-empty words.
    """
    words = [clean(w) for w in text.split(',')]
    if len(words) == 1:
        raise ValueError(f'ERROR: {label} contains single word {words}')
    if len(words) > 2:
        raise ValueError(f'ERROR: {label} contains more than two word {words}')
    # "a," or ",b" split into two items but one of them is empty.
    if not all(words):
        raise ValueError(f'ERROR: {label} contains an empty word {words}')
    return words


def greet(
        query,
        candidate_1,
        candidate_2,
        candidate_3,
        candidate_4,
        candidate_5,
        candidate_6):
    """Score candidate word pairs by similarity to the query word pair.

    Each argument is a string holding two comma-separated words. Blank
    (empty or whitespace-only) candidates are ignored. The query and the
    remaining candidates are embedded with ``model.get_embedding`` and each
    candidate is compared to the query with ``cosine_similarity``.

    Args:
        query: The query word pair, e.g. ``"word1, word2"``.
        candidate_1 .. candidate_6: Candidate word pairs; at least two must
            be non-blank.

    Returns:
        dict: Maps ``'candidate N: [w1, w2]'`` (N is the 1-based slot of the
        candidate) to its cosine similarity with the query embedding.

    Raises:
        ValueError: If the query is empty, any supplied pair does not hold
            exactly two non-empty words, or fewer than two candidates are
            given.
    """
    # str.split never returns an empty list, so emptiness has to be checked
    # on the text itself before splitting.
    if not clean(query):
        raise ValueError('ERROR: query is empty')
    query = _parse_pair(query, 'query')

    pairs = []
    pairs_id = []
    candidates = [
        candidate_1,
        candidate_2,
        candidate_3,
        candidate_4,
        candidate_5,
        candidate_6,
    ]
    for n, raw in enumerate(candidates, start=1):
        # Treat whitespace-only boxes the same as untouched (empty) ones.
        if not clean(raw):
            continue
        pairs.append(_parse_pair(raw, f'candidate {n}'))
        pairs_id.append(n)
    if len(pairs_id) < 2:
        raise ValueError(f'ERROR: please specify at least two candidates: {pairs}')

    # Embed everything in one call; the query is appended last so it can be
    # popped back off the end of the result.
    vectors = model.get_embedding(pairs + [query])
    vector_q = vectors.pop(-1)
    sims = [cosine_similarity(v, vector_q) for v in vectors]
    output = {f'candidate {i}: [{p[0]}, {p[1]}]': s for i, s, p in zip(pairs_id, sims, pairs)}
    return output


# Build the example rows shown under the form: the first 15 dataset records,
# each rendered as "stem" followed by its choices, padded with empty strings
# so every row has one value per candidate textbox.
examples = []
for record in get_example()[:15]:
    padded_choices = record['choice'] + [''] * (6 - len(record['choice']))
    row = [','.join(record['stem'])]
    row.extend(','.join(choice) for choice in padded_choices)
    examples.append(row)

# One textbox for the query followed by six candidate textboxes, in the same
# order as the parameters of `greet`.
query_box = gr.Textbox(lines=1, placeholder="Query Word Pair (separate by comma)")
candidate_boxes = [
    gr.Textbox(lines=1, placeholder=f"Candidate Word Pair {slot} (separate by comma)")
    for slot in range(1, 7)
]

demo = gr.Interface(
    fn=greet,
    inputs=[query_box] + candidate_boxes,
    outputs="label",
    examples=examples,
)
demo.launch(show_error=True)