antoinelouis
committed on
Commit
•
02ffc6e
1
Parent(s):
ffea12f
Updated app file
Browse files
app.py
CHANGED
@@ -3,24 +3,27 @@ import pandas as pd
|
|
3 |
|
4 |
import torch
|
5 |
from datasets import load_dataset
|
6 |
-
from sentence_transformers import SentenceTransformer, util
|
7 |
|
8 |
|
9 |
-
q_encoder = SentenceTransformer(
|
10 |
-
|
|
|
|
|
|
|
11 |
docs = pd.DataFrame(load_dataset("antoiloui/bsard", data_files="articles_fr.csv")['train'])
|
12 |
|
13 |
-
def search(
|
14 |
-
q_emb = q_encoder.encode(
|
15 |
hits = util.semantic_search(q_emb, doc_embeddings, top_k=100, score_function=util.cos_sim)[0]
|
16 |
-
return {docs.loc[h['corpus_id'], 'article']
|
17 |
|
18 |
gr.Interface(
|
19 |
fn=search,
|
20 |
-
inputs=
|
21 |
-
outputs=[
|
22 |
title="Legislation Search 🇧🇪",
|
23 |
description="",
|
24 |
-
|
25 |
examples=["Qu'est-ce que je risque si je viole le secret professionnel ?", "Mon employeur peut-il me licencier alors que je suis malade ?"]
|
26 |
).launch(share=False, enable_queue=False)
|
|
|
3 |
|
4 |
import torch
|
5 |
from datasets import load_dataset
|
6 |
+
from sentence_transformers import SentenceTransformer, util, models
|
7 |
|
8 |
|
9 |
+
# Question encoder: a transformer loaded from a local checkpoint directory,
# followed by a pooling layer configured with pooling_mode='cls'.
q_encoder = SentenceTransformer(
    modules=[
        models.Transformer(model_name_or_path="checkpoints/q_encoder", max_seq_length=512),
        models.Pooling(word_embedding_dimension=768, pooling_mode='cls'),
    ],
)

# Precomputed document embeddings, mapped onto the CPU at load time.
# NOTE(review): presumably one row per article, in the same order as `docs` -- confirm.
doc_embeddings = torch.load('checkpoints/doc_embeddings.pt', map_location=torch.device('cpu'))

# Article table built from the 'train' split of the BSARD articles CSV.
docs = pd.DataFrame(load_dataset("antoiloui/bsard", data_files="articles_fr.csv")['train'])
|
15 |
|
16 |
+
def search(question):
    """Return the five articles most similar to ``question``, best match first.

    The question is embedded with ``q_encoder`` and scored by cosine
    similarity against ``doc_embeddings``.  Each result is the article text
    followed by a blank line and a "- Art. <article_no>, <code>" reference.

    Args:
        question: The user's query text.

    Returns:
        A list of formatted strings in descending order of similarity, one
        per output textbox of the interface.
    """
    q_emb = q_encoder.encode(question, convert_to_tensor=True)
    hits = util.semantic_search(q_emb, doc_embeddings, top_k=100, score_function=util.cos_sim)[0]

    # Build a list rather than a set: a set discards the relevance ranking
    # and merges identical strings, which could yield fewer values than the
    # interface has output components.
    results = []
    for hit in hits[:5]:
        # NOTE(review): assumes corpus_id is a valid row label of `docs`,
        # i.e. doc_embeddings rows are aligned with `docs` rows -- confirm.
        cid = hit['corpus_id']
        results.append(
            docs.loc[cid, 'article']
            + '\n\n'
            + f"- Art. {docs.loc[cid, 'article_no']}, {docs.loc[cid, 'code']}"
        )
    return results
|
20 |
|
21 |
# Build the web UI: one text input routed to `search`, five textbox outputs
# (one per returned article), then start serving it.
gr.Interface(
    fn=search,
    inputs=['text'],
    outputs=['textbox' for _ in range(5)],
    title="Legislation Search 🇧🇪",
    description="",
    allow_flagging="auto",
    examples=[
        "Qu'est-ce que je risque si je viole le secret professionnel ?",
        "Mon employeur peut-il me licencier alors que je suis malade ?",
    ],
).launch(share=False, enable_queue=False)
|