|
import gradio as gr |
|
import pandas as pd |
|
|
|
import torch |
|
from datasets import load_dataset |
|
from sentence_transformers import SentenceTransformer, util |
|
|
|
|
|
# Question encoder whose embeddings are compared with the article embeddings.
q_encoder = SentenceTransformer("checkpoints/q_encoder")

# Precomputed article embeddings. map_location puts the tensor on the encoder's
# device, so a checkpoint saved on a GPU still loads on a CPU-only host.
# NOTE(review): presumably row i corresponds to row i of `docs` below — confirm.
doc_embeddings = torch.load(
    "checkpoints/doc_embeddings.pt", map_location=q_encoder.device
)

# Article table; search() reads its 'article', 'article_no' and 'code' columns.
# Dataset.to_pandas() converts the Arrow table directly instead of iterating rows.
docs = load_dataset("antoiloui/bsard", data_files="articles_fr.csv")["train"].to_pandas()
|
|
|
def search(query):
    """Return the best-matching article and its citation for a question.

    The question is embedded with ``q_encoder`` and ranked against
    ``doc_embeddings`` by cosine similarity; the top hit is looked up in
    ``docs``.

    Args:
        query: Free-text question from the Gradio input textbox.

    Returns:
        A ``(article_text, reference)`` tuple, one value per output textbox
        ("Result" and "Reference"). Both are empty strings when the query is
        blank or the search yields no hit.
    """
    # A blank question would still be encoded and "match" some arbitrary
    # article; show nothing instead.
    if not query or not query.strip():
        return "", ""

    q_emb = q_encoder.encode(query, convert_to_tensor=True)
    # Only the best hit is displayed, so request exactly one.
    hits = util.semantic_search(
        q_emb, doc_embeddings, top_k=1, score_function=util.cos_sim
    )[0]
    if not hits:
        return "", ""

    # corpus_id is a row position in doc_embeddings, hence positional .iloc.
    # Assumes docs rows are in the same order as doc_embeddings — TODO confirm.
    row = docs.iloc[hits[0]["corpus_id"]]
    return row["article"], f"Art. {row['article_no']}, {row['code']}"
|
|
|
# Gradio UI: one question textbox in; two textboxes out ("Result" for the
# article text, "Reference" for its citation), with thumbs-up/down flagging.
demo = gr.Interface(
    fn=search,
    inputs=gr.Textbox(label="Question", placeholder=""),
    outputs=[gr.Textbox(lines=5, label="Result"), gr.Textbox(label="Reference")],
    title="Legislation Search 🇧🇪",
    description="",
    flagging_options=["👍", "👎"],
    examples=[
        "Qu'est-ce que je risque si je viole le secret professionnel ?",
        "Mon employeur peut-il me licencier alors que je suis malade ?",
    ],
)
# `enable_queue` is no longer passed: it is deprecated in Gradio 3.x (where the
# queue stays off unless .queue() is called) and was removed from launch() in 4.0.
demo.launch(share=False)
|
|