eliot-hub commited on
Commit
4972318
1 Parent(s): 98ff3d8

first commit

Browse files
Files changed (2) hide show
  1. app.py +65 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from sentence_transformers import SentenceTransformer
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ import vecs
5
+
6
+
7
+ # Load models
8
+ model_ret = SentenceTransformer("intfloat/multilingual-e5-large")
9
+
10
+ tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
11
+ model_it = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it") #, device_map="auto")
12
+
13
+ # Init session
14
+ user = SUPABASE_USER
15
+ password = SUPABASE_PASSWORD
16
+ host = "aws-0-eu-central-1.pooler.supabase.com"
17
+ port = "5432"
18
+ db_name = "postgres"
19
+
20
+ DB_CONNECTION = f"postgresql://{user}:{password}@{host}:{port}/{db_name}"
21
+
22
+ examples = [
23
+ "Comment fonctionne l'assurance emprunteur ?",
24
+ "Qu'est ce que l'euro croissance ?"
25
+ ]
26
+
27
+ def pipeline(query):
28
+ query_emb = model_ret.encode(query).tolist()
29
+
30
+ with vecs.create_client(DB_CONNECTION) as vx:
31
+
32
+ resume = vx.get_or_create_collection(
33
+ name="resume_vec",
34
+ dimension=1024,
35
+ )
36
+ result = resume.query(
37
+ data=query_emb,
38
+ limit=5,
39
+ measure="cosine_distance",
40
+ include_value=True,
41
+ include_metadata=True,
42
+ )
43
+ source_information = " ".join([e[2]["body"] for e in result])
44
+
45
+ combined_information = (
46
+ f"Requête: {query}\nRéponds à la requête en te basant sur le contexte suivant :\n{source_information}. \nRéponse:"
47
+ )
48
+
49
+ input_ids = tokenizer(combined_information, return_tensors="pt") #.to("cuda")
50
+ response = model_it.generate(**input_ids, max_new_tokens=500)
51
+ response_str = tokenizer.decode(response[0])
52
+ return response_str.split("Réponse:")[-1].strip().replace("<eos>", "")
53
+
54
+
55
+ demo = gr.Interface(
56
+ fn=pipeline,
57
+ inputs=gr.Textbox(label="Input", lines=2), #"text",
58
+ outputs=gr.Textbox(label="Output", lines=2),
59
+ title="RAG gemma-2b",
60
+ examples=examples,
61
+ allow_flagging="never",
62
+ theme=gr.themes.Default()
63
+ )
64
+
65
+ demo.launch(inbrowser=True) #share=True
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio==4.21.0
2
+ sentence-transformers==2.5.1
3
+ transformers @ git+https://github.com/huggingface/transformers.git
4
+ vecs==0.4.3