"""Gradio demo: semantic search over French-language Belgian legislation articles (BSARD)."""
import gradio as gr
import pandas as pd
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util, models

# Number of articles shown in the UI; the interface gets one output textbox per result.
NUM_RESULTS = 5

# Question encoder: transformer loaded from a local checkpoint, pooled on the CLS token
# into a 768-dimensional embedding.
q_encoder = SentenceTransformer(modules=[
    models.Transformer(model_name_or_path="checkpoints/q_encoder", max_seq_length=512),
    models.Pooling(word_embedding_dimension=768, pooling_mode='cls'),
])

# Precomputed article embeddings, mapped onto the CPU.
# NOTE(review): torch.load unpickles the file; only point it at trusted checkpoints.
doc_embeddings = torch.load('checkpoints/doc_embeddings.pt', map_location=torch.device('cpu'))

# Article corpus. The corpus_id returned by semantic_search is used directly as a row label
# here, so this assumes row i of `docs` is the article embedded in row i of doc_embeddings
# — TODO confirm against the script that produced doc_embeddings.pt.
docs = pd.DataFrame(load_dataset("antoiloui/bsard", data_files="articles_fr.csv")['train'])


def _format_hit(corpus_id):
    """Render one article as its text, a blank line, then an "- Art. <no>, <code>" citation."""
    row = docs.loc[corpus_id]
    return row['article'] + '\n\n' + f"- Art. {row['article_no']}, {row['code']}"


def search(question):
    """Return the NUM_RESULTS articles most similar to *question*, best match first.

    Args:
        question: free-text query typed into the Gradio text input.

    Returns:
        A list of exactly NUM_RESULTS strings, one per output textbox, ordered by
        decreasing cosine similarity. It is padded with empty strings if the corpus
        yields fewer hits.
    """
    q_emb = q_encoder.encode(question, convert_to_tensor=True)
    # Ask only for the hits we display; [0] selects the results for our single query.
    hits = util.semantic_search(q_emb, doc_embeddings, top_k=NUM_RESULTS,
                                score_function=util.cos_sim)[0]
    # A list (not a set) keeps the similarity ranking and one entry per output box.
    results = [_format_hit(h['corpus_id']) for h in hits]
    # Gradio requires one value per output component.
    results += [''] * (NUM_RESULTS - len(results))
    return results


demo = gr.Interface(
    fn=search,
    inputs=['text'],
    outputs=['textbox'] * NUM_RESULTS,
    title="Legislation Search 🇧🇪",
    description="",
    # NOTE(review): allow_flagging / enable_queue are Gradio 3.x-era arguments that later
    # releases deprecate or remove — confirm against the pinned gradio version.
    allow_flagging="auto",
    examples=["Qu'est-ce que je risque si je viole le secret professionnel ?",
              "Mon employeur peut-il me licencier alors que je suis malade ?"],
)
demo.launch(share=False, enable_queue=False)