"""Gradio demo: score candidate word pairs against a query word pair.

Every pair is embedded with RelBERT and each candidate is scored by the
cosine similarity between its embedding and the query pair's embedding.
"""
import json
import re

import gradio as gr
import requests
from relbert import RelBERT

_EXAMPLE_URL = "https://huggingface.co/datasets/relbert/analogy_questions/raw/main/dataset/sat/test.jsonl"
# Number of candidate inputs in the UI; must match the candidate_N parameters of `greet`.
_NUM_CANDIDATES = 6
# Leading or trailing run of whitespace.
_EDGE_WHITESPACE = re.compile(r"\A\s+|\s+\Z")

model = RelBERT(model='relbert/relbert-roberta-large')


def get_example(timeout: float = 30.0):
    """Download the analogy-question JSONL file and parse it.

    Args:
        timeout: Seconds to wait for the server before giving up.

    Returns:
        A list with one parsed JSON object per non-blank line of the file.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server answers with an error status.
    """
    response = requests.get(_EXAMPLE_URL, timeout=timeout)
    # Surface HTTP errors directly instead of a JSON decode error on an error page.
    response.raise_for_status()
    lines = response.content.decode().split('\n')
    return [json.loads(line) for line in lines if line.strip()]


def cosine_similarity(a, b, zero_vector_mask: float = -100):
    """Return the cosine similarity of two equal-length numeric sequences.

    Args:
        a: First vector.
        b: Second vector.
        zero_vector_mask: Value returned when the product of the two norms is
            zero, where the similarity is undefined.

    Returns:
        The cosine similarity, or ``zero_vector_mask`` for a zero-norm input.
    """
    norm_a = sum(x * x for x in a) ** 0.5
    norm_b = sum(x * x for x in b) ** 0.5
    if norm_a * norm_b == 0:
        return zero_vector_mask
    return sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)


def clean(text):
    """Return ``text`` with leading and trailing whitespace removed."""
    return _EDGE_WHITESPACE.sub("", text)


def _parse_pair(text, label):
    """Split a comma-separated string into exactly two non-empty, cleaned words.

    Args:
        text: Raw user input such as ``"word_a, word_b"``.
        label: Name of the field, used as the subject of error messages.

    Returns:
        A two-element list of words.

    Raises:
        ValueError: If ``text`` does not hold exactly two non-empty words.
    """
    words = [clean(word) for word in text.split(',')]
    if len(words) == 1:
        raise ValueError(f'ERROR: {label} contains single word {words}')
    if len(words) > 2:
        raise ValueError(f'ERROR: {label} contains more than two word {words}')
    if not all(words):
        # e.g. "a," splits into two items but only one of them is a word.
        raise ValueError(f'ERROR: {label} contains an empty word {words}')
    return words


def greet(
        query,
        candidate_1,
        candidate_2,
        candidate_3,
        candidate_4,
        candidate_5,
        candidate_6):
    """Score each filled-in candidate pair against the query pair.

    Args:
        query: Query word pair as a comma-separated string.
        candidate_1: Candidate word pair as a comma-separated string; a blank
            value means the slot is unused.
        candidate_2: Same as ``candidate_1``.
        candidate_3: Same as ``candidate_1``.
        candidate_4: Same as ``candidate_1``.
        candidate_5: Same as ``candidate_1``.
        candidate_6: Same as ``candidate_1``.

    Returns:
        A dict mapping ``'candidate N: [word_a, word_b]'`` to the cosine
        similarity between that candidate's embedding and the query's.

    Raises:
        ValueError: If the query is blank, any pair is malformed, or fewer
            than two candidates are given.
    """
    # str.split never returns an empty list, so blank input is detected up front.
    if not query or not clean(query):
        raise ValueError('ERROR: query is empty')
    query_pair = _parse_pair(query, 'query')

    pairs = []
    pairs_id = []
    candidates = [candidate_1, candidate_2, candidate_3, candidate_4, candidate_5, candidate_6]
    for n, text in enumerate(candidates, start=1):
        # Blank slots (empty, whitespace-only or None) are simply unused.
        if not text or not clean(text):
            continue
        pairs.append(_parse_pair(text, f'candidate {n}'))
        pairs_id.append(n)
    if len(pairs_id) < 2:
        raise ValueError(f'ERROR: please specify at least two candidates: {pairs}')

    # Embed everything in one call; the query was appended last, so its vector is last.
    *candidate_vectors, query_vector = model.get_embedding(pairs + [query_pair])
    return {
        f'candidate {n}: [{pair[0]}, {pair[1]}]': cosine_similarity(vector, query_vector)
        for n, pair, vector in zip(pairs_id, pairs, candidate_vectors)
    }


def _example_row(question):
    """Build one row of example inputs (query + candidates) from a question dict.

    Reads the ``'stem'`` and ``'choice'`` entries of ``question``. Choices are
    truncated or padded with empty strings to match the number of inputs.
    """
    choices = [','.join(choice) for choice in question['choice'][:_NUM_CANDIDATES]]
    padding = [''] * (_NUM_CANDIDATES - len(choices))
    return [','.join(question['stem'])] + choices + padding


examples = [_example_row(question) for question in get_example()[:15]]

demo = gr.Interface(
    fn=greet,
    inputs=[
        gr.Textbox(lines=1, placeholder="Query Word Pair (separate by comma)"),
        *(
            gr.Textbox(lines=1, placeholder=f"Candidate Word Pair {n} (separate by comma)")
            for n in range(1, _NUM_CANDIDATES + 1)
        ),
    ],
    outputs="label",
    examples=examples
)
demo.launch(show_error=True)