# File size: 1,144 Bytes
# 7473ba2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import gradio as gr
import pandas as pd

import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util


# Encoder used by search() to embed the user's question before comparing it
# against doc_embeddings.
q_encoder = SentenceTransformer("checkpoints/q_encoder")
# Precomputed article embeddings (presumably one row per article — TODO confirm).
# map_location lets a tensor that was saved from a GPU still load on a CPU-only
# host, and places it on the same device as the encoder's outputs.
doc_embeddings = torch.load('checkpoints/doc_embeddings.pt', map_location=q_encoder.device)
# Article corpus. search() maps each hit's corpus_id (a row position in
# doc_embeddings) to a row of this frame, so the row order here is assumed to
# match the order used when doc_embeddings.pt was built — TODO confirm.
docs = pd.DataFrame(load_dataset("antoiloui/bsard", data_files="articles_fr.csv")['train'])

def search(query, top_k=5):
    """Return the ``top_k`` articles most similar to ``query``.

    The query is embedded with ``q_encoder`` and ranked by cosine similarity
    against the precomputed ``doc_embeddings``; each hit is resolved to the
    row of ``docs`` at the same position.

    Args:
        query: Free-text question.
        top_k: Number of best matches to return (default 5, as before).

    Returns:
        Dict mapping each matched article's text to its reference string
        ``"Art. <article_no>, <code>"``, in descending score order. Hits with
        identical article text collapse into one key (kept at the position of
        the first such hit, with the reference of the last).
    """
    q_emb = q_encoder.encode(query, convert_to_tensor=True)
    # Request exactly the number of hits we keep: ranking 100 candidates only to
    # discard all but the first few was wasted work with the same result.
    hits = util.semantic_search(
        q_emb, doc_embeddings, top_k=top_k, score_function=util.cos_sim
    )[0]
    results = {}
    for hit in hits:
        # corpus_id is a row position in doc_embeddings, so index docs by
        # position and fetch the row once instead of three separate lookups.
        row = docs.iloc[hit["corpus_id"]]
        results[row["article"]] = f"Art. {row['article_no']}, {row['code']}"
    return results

def _search_to_outputs(query):
    """Adapt ``search`` to the interface's two text outputs.

    ``search`` returns a dict ``{article text: reference}``. Gradio expects a
    multi-output function to return one value per output component (a tuple),
    or a dict keyed by the output components themselves; a dict keyed by
    article strings matches neither, so it cannot be returned directly.

    Returns:
        ``(articles, references)``: the matched articles numbered by rank and
        separated by blank lines, and their references numbered the same way,
        one per line.
    """
    hits = search(query)
    articles = "\n\n".join(
        f"[{rank}] {article}" for rank, article in enumerate(hits, start=1)
    )
    references = "\n".join(
        f"[{rank}] {reference}" for rank, reference in enumerate(hits.values(), start=1)
    )
    return articles, references


demo = gr.Interface(
    fn=_search_to_outputs,
    inputs=gr.Textbox(label="Question", placeholder=""),
    # Reference now holds one line per hit, so give it room for several lines.
    outputs=[gr.Textbox(lines=5, label="Result"), gr.Textbox(lines=5, label="Reference")],
    title="Legislation Search 🇧🇪",
    description="",
    flagging_options=["👍","👎"],
    examples=["Qu'est-ce que je risque si je viole le secret professionnel ?", "Mon employeur peut-il me licencier alors que je suis malade ?"]
)
demo.launch(share=False, enable_queue=False)