antoinelouis committed on
Commit
02ffc6e
1 Parent(s): ffea12f

Updated app file

Browse files
Files changed (1) hide show
  1. app.py +12 -9
app.py CHANGED
@@ -3,24 +3,27 @@ import pandas as pd
3
 
4
  import torch
5
  from datasets import load_dataset
6
- from sentence_transformers import SentenceTransformer, util
7
 
8
 
9
- q_encoder = SentenceTransformer("checkpoints/q_encoder")
10
- doc_embeddings = torch.load('checkpoints/doc_embeddings.pt')
 
 
 
11
  docs = pd.DataFrame(load_dataset("antoiloui/bsard", data_files="articles_fr.csv")['train'])
12
 
13
- def search(query):
14
- q_emb = q_encoder.encode(query, convert_to_tensor=True)
15
  hits = util.semantic_search(q_emb, doc_embeddings, top_k=100, score_function=util.cos_sim)[0]
16
- return {docs.loc[h['corpus_id'], 'article']: f"Art. {docs.loc[h['corpus_id'], 'article_no']}, {docs.loc[h['corpus_id'], 'code']}" for h in hits[:5]}
17
 
18
  gr.Interface(
19
  fn=search,
20
- inputs=gr.Textbox(label="Question", placeholder=""),
21
- outputs=[gr.Textbox(lines=5, label="Result"),gr.Textbox(label="Reference")],
22
  title="Legislation Search 🇧🇪",
23
  description="",
24
- flagging_options=["👍","👎"],
25
  examples=["Qu'est-ce que je risque si je viole le secret professionnel ?", "Mon employeur peut-il me licencier alors que je suis malade ?"]
26
  ).launch(share=False, enable_queue=False)
 
3
 
4
  import torch
5
  from datasets import load_dataset
6
+ from sentence_transformers import SentenceTransformer, util, models
7
 
8
 
9
+ q_encoder = SentenceTransformer(modules=[
10
+ models.Transformer(model_name_or_path="checkpoints/q_encoder", max_seq_length=512),
11
+ models.Pooling(word_embedding_dimension=768, pooling_mode='cls'),
12
+ ])
13
+ doc_embeddings = torch.load('checkpoints/doc_embeddings.pt', map_location=torch.device('cpu'))
14
  docs = pd.DataFrame(load_dataset("antoiloui/bsard", data_files="articles_fr.csv")['train'])
15
 
16
+ def search(question):
17
+ q_emb = q_encoder.encode(question, convert_to_tensor=True)
18
  hits = util.semantic_search(q_emb, doc_embeddings, top_k=100, score_function=util.cos_sim)[0]
19
+ return {docs.loc[h['corpus_id'], 'article'] + '\n\n' + f"- Art. {docs.loc[h['corpus_id'], 'article_no']}, {docs.loc[h['corpus_id'], 'code']}" for h in hits[:5]}
20
 
21
  gr.Interface(
22
  fn=search,
23
+ inputs=['text'],
24
+ outputs=['textbox']*5,
25
  title="Legislation Search 🇧🇪",
26
  description="",
27
+ allow_flagging="auto",
28
  examples=["Qu'est-ce que je risque si je viole le secret professionnel ?", "Mon employeur peut-il me licencier alors que je suis malade ?"]
29
  ).launch(share=False, enable_queue=False)