"""Gradio demo: rank candidate word pairs by RelBERT relational similarity to a query pair."""
import json
import logging

import requests
from relbert import RelBERT
import gradio as gr

logger = logging.getLogger(__name__)

# Number of candidate textboxes in the UI; must match greet's candidate_* parameters.
MAX_CANDIDATES = 6
EXAMPLE_URL = "https://huggingface.co/datasets/relbert/analogy_questions/raw/main/dataset/sat/test.jsonl"

model = RelBERT(model='relbert/relbert-roberta-large')


def get_example(url=EXAMPLE_URL, timeout=30):
    """Download a JSONL analogy dataset (SAT test split by default) and return its records.

    Each non-blank line is parsed with ``json.loads``; records are returned in file order.

    Raises:
        requests.RequestException: on network failure, timeout, or an HTTP error status.
    """
    r = requests.get(url, timeout=timeout)
    # Fail on the HTTP error itself rather than later with a confusing
    # JSONDecodeError when an error page is parsed as JSONL.
    r.raise_for_status()
    return [json.loads(line) for line in r.content.decode('utf-8').splitlines() if line.strip()]


def cosine_similarity(a, b, zero_vector_mask: float = -100):
    """Return the cosine similarity of vectors ``a`` and ``b``.

    Cosine similarity is undefined when either vector has zero norm; in that
    case ``zero_vector_mask`` is returned instead.
    """
    norm_a = sum(x * x for x in a) ** 0.5
    norm_b = sum(x * x for x in b) ** 0.5
    if norm_a * norm_b == 0:
        return zero_vector_mask
    return sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)


def _parse_word_pair(text, label):
    """Split comma-separated ``text`` into exactly two stripped, non-empty words.

    ``label`` (e.g. ``'query'`` or ``'candidate 3'``) is used in error messages.

    Raises:
        ValueError: if ``text`` is empty or is not exactly two non-empty words.
    """
    words = [w.strip() for w in text.split(',')]
    if not any(words):
        raise ValueError(f'ERROR: {label} is empty {words}')
    if len(words) == 1:
        raise ValueError(f'ERROR: {label} contains single word {words}')
    if len(words) > 2:
        raise ValueError(f'ERROR: {label} contains more than two word {words}')
    if not all(words):
        raise ValueError(f'ERROR: {label} contains an empty word {words}')
    return words


def greet(
        query,
        candidate_1,
        candidate_2,
        candidate_3,
        candidate_4,
        candidate_5,
        candidate_6):
    """Rank candidate word pairs by RelBERT embedding similarity to the query pair.

    Args:
        query: comma-separated query word pair, e.g. ``"ostrich,bird"``.
        candidate_1 .. candidate_6: comma-separated candidate pairs; blank or
            whitespace-only candidates are skipped.

    Returns:
        dict mapping ``'candidate <k>: [<w1>, <w2>]'`` to the cosine similarity
        between that candidate's embedding and the query's, where ``k`` is the
        candidate's input position (1-based). Inserted in descending similarity.

    Raises:
        ValueError: on a malformed pair or fewer than two non-blank candidates.
    """
    query = _parse_word_pair(query, 'query')

    candidates = [candidate_1, candidate_2, candidate_3, candidate_4, candidate_5, candidate_6]
    pairs = []
    pairs_id = []
    for n, text in enumerate(candidates, start=1):
        if not (text or '').strip():
            continue
        pairs.append(_parse_word_pair(text, f'candidate {n}'))
        pairs_id.append(n)
    if len(pairs) < 2:
        raise ValueError(f'ERROR: please specify at least two candidates: {pairs}')

    # Embed candidates and query in one call; the query is the last entry.
    *candidate_vectors, vector_q = model.get_embedding(pairs + [query])
    sims = [cosine_similarity(v, vector_q) for v in candidate_vectors]
    ranked = sorted(zip(pairs_id, sims, pairs), key=lambda item: item[1], reverse=True)
    # Key on the candidate's input position (not its rank) so each label
    # identifies which textbox the pair came from.
    return {f'candidate {cid}: [{p[0]}, {p[1]}]': s for cid, s, p in ranked}


def _build_examples(records, n_candidates=MAX_CANDIDATES):
    """Convert analogy records into Gradio example rows ``[query, cand_1, ..., cand_n]``.

    Choices beyond ``n_candidates`` are dropped and missing ones padded with
    ``''`` so every row matches the number of input components.
    """
    rows = []
    for record in records:
        choices = [','.join(c) for c in record['choice']][:n_candidates]
        choices += [''] * (n_candidates - len(choices))
        rows.append([','.join(record['stem'])] + choices)
    return rows


try:
    examples = _build_examples(get_example()[:15])
except requests.RequestException as err:
    # The demo works without examples, so a failed download should not stop it.
    logger.warning('could not download examples: %s', err)
    examples = None

demo = gr.Interface(
    fn=greet,
    inputs=[gr.Textbox(lines=1, placeholder="Query Word Pair (separate by comma)")] + [
        gr.Textbox(lines=1, placeholder=f"Candidate Word Pair {k} (separate by comma)")
        for k in range(1, MAX_CANDIDATES + 1)
    ],
    outputs="label",
    examples=examples,
)
demo.launch(show_error=True)