File size: 1,249 Bytes
7473ba2 02ffc6e 7473ba2 02ffc6e 7473ba2 02ffc6e 7473ba2 02ffc6e 7473ba2 02ffc6e 826a31b 7473ba2 f0c67b3 7473ba2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import gradio as gr
import pandas as pd
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util, models
# Question encoder: a transformer loaded from a local checkpoint, followed by
# a pooling layer in 'cls' mode, assembled explicitly from its two modules.
_word_embedding_model = models.Transformer(
    model_name_or_path="checkpoints/q_encoder",
    max_seq_length=512,
)
_cls_pooling = models.Pooling(
    word_embedding_dimension=768,
    pooling_mode='cls',
)
q_encoder = SentenceTransformer(modules=[_word_embedding_model, _cls_pooling])

# Document embeddings saved on disk, mapped onto the CPU while loading.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
doc_embeddings = torch.load(
    'checkpoints/doc_embeddings.pt',
    map_location=torch.device('cpu'),
)

# Article table built from the 'train' split of articles_fr.csv.
# NOTE(review): presumably row i of `docs` matches row i of `doc_embeddings`,
# since search() indexes `docs` with the corpus ids of the hits -- confirm.
docs = pd.DataFrame(
    load_dataset(
        "antoiloui/bsard",
        data_files="articles_fr.csv",
    )['train']
)
def search(question):
    """Return the five best-matching articles for a question, best match first.

    The question is embedded with ``q_encoder`` and compared against
    ``doc_embeddings`` using cosine similarity. Each retained hit is rendered
    as the article text, a blank line, and a ``- Art. <article_no>, <code>``
    citation line looked up in ``docs``.

    Args:
        question: Query text to embed and search for.

    Returns:
        A list of formatted strings in the order the hits are returned by
        ``util.semantic_search``. A list (rather than a set) is returned so
        the ranking is preserved, identical strings are not collapsed, and
        each entry can be mapped by position to one output component.
    """
    q_emb = q_encoder.encode(question, convert_to_tensor=True)
    # semantic_search returns one hit list per query; a single query is sent.
    # NOTE(review): only the first 5 of the 100 requested hits are used, so
    # top_k could probably be lowered -- confirm before changing.
    hits = util.semantic_search(q_emb, doc_embeddings, top_k=100, score_function=util.cos_sim)[0]

    results = []
    for hit in hits[:5]:
        # Resolve the corpus id once instead of once per column lookup.
        idx = hit['corpus_id']
        citation = f"- Art. {docs.loc[idx, 'article_no']}, {docs.loc[idx, 'code']}"
        results.append(docs.loc[idx, 'article'] + '\n\n' + citation)
    return results
# Sample questions offered as examples in the UI.
_example_questions = [
    "Qu'est-ce que je risque si je viole le secret professionnel ?",
    "Mon employeur peut-il me licencier alors que je suis malade ?",
]

# One text input wired to search(), whose results fill five textboxes.
_demo = gr.Interface(
    fn=search,
    inputs=['text'],
    outputs=['textbox'] * 5,
    title="Belgian Legislation Search",
    description="",
    allow_flagging="never",
    examples=_example_questions,
)

# Start the app without a public share link and with queueing turned off.
_demo.launch(share=False, enable_queue=False)
|