first commit
- app.py +65 -0
- requirements.txt +4 -0
app.py
ADDED
@@ -0,0 +1,65 @@
import os

import gradio as gr
import vecs
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load models: multilingual-e5-large as the retriever, gemma-2b-it as the generator
model_ret = SentenceTransformer("intfloat/multilingual-e5-large")

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
model_it = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it")  # , device_map="auto"

# Init session: credentials are read from the environment (e.g. Space secrets)
user = os.environ["SUPABASE_USER"]
password = os.environ["SUPABASE_PASSWORD"]
host = "aws-0-eu-central-1.pooler.supabase.com"
port = "5432"
db_name = "postgres"

DB_CONNECTION = f"postgresql://{user}:{password}@{host}:{port}/{db_name}"

# Example queries (French): "How does borrower insurance work?" and
# "What is euro-croissance?"
examples = [
    "Comment fonctionne l'assurance emprunteur ?",
    "Qu'est ce que l'euro croissance ?",
]


def pipeline(query):
    # Embed the query with the same model used to index the documents
    query_emb = model_ret.encode(query).tolist()

    # Retrieve the 5 nearest chunks from the "resume_vec" collection
    with vecs.create_client(DB_CONNECTION) as vx:
        resume = vx.get_or_create_collection(
            name="resume_vec",
            dimension=1024,  # embedding size of multilingual-e5-large
        )
        result = resume.query(
            data=query_emb,
            limit=5,
            measure="cosine_distance",
            include_value=True,
            include_metadata=True,
        )
    source_information = " ".join([e[2]["body"] for e in result])

    # French prompt, roughly: "Query: {query}\nAnswer the query based on the
    # following context:\n{context}.\nAnswer:"
    combined_information = (
        f"Requête: {query}\nRéponds à la requête en te basant sur le contexte suivant :\n{source_information}. \nRéponse:"
    )

    inputs = tokenizer(combined_information, return_tensors="pt")  # .to("cuda")
    response = model_it.generate(**inputs, max_new_tokens=500)
    response_str = tokenizer.decode(response[0])
    # Keep only the text after "Réponse:" and drop the trailing <eos> token
    return response_str.split("Réponse:")[-1].strip().replace("<eos>", "")


demo = gr.Interface(
    fn=pipeline,
    inputs=gr.Textbox(label="Input", lines=2),
    outputs=gr.Textbox(label="Output", lines=2),
    title="RAG gemma-2b",
    examples=examples,
    allow_flagging="never",
    theme=gr.themes.Default(),
)

demo.launch(inbrowser=True)  # share=True
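app.py queries a pre-existing resume_vec collection but never creates its contents. A minimal indexing sketch follows, assuming the same embedding model was used to build the collection; the document ids and texts are hypothetical stand-ins, as is the script itself.

# indexing_sketch.py - hypothetical script that populates the "resume_vec"
# collection queried by app.py; document texts below are placeholders.
import os

import vecs
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("intfloat/multilingual-e5-large")  # 1024-dim vectors

docs = [
    ("doc-1", "L'assurance emprunteur garantit le remboursement du prêt..."),
    ("doc-2", "L'euro-croissance est un support d'assurance-vie..."),
]

DB_CONNECTION = (
    f"postgresql://{os.environ['SUPABASE_USER']}:{os.environ['SUPABASE_PASSWORD']}"
    "@aws-0-eu-central-1.pooler.supabase.com:5432/postgres"
)

vx = vecs.create_client(DB_CONNECTION)
resume = vx.get_or_create_collection(name="resume_vec", dimension=1024)

# Each record is (id, vector, metadata); app.py reads metadata["body"]
records = [
    (doc_id, model.encode(text).tolist(), {"body": text})
    for doc_id, text in docs
]
resume.upsert(records)

# Build an index for cosine queries, matching measure="cosine_distance" in app.py
resume.create_index(measure=vecs.IndexMeasure.cosine_distance)
vx.disconnect()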
requirements.txt
ADDED
@@ -0,0 +1,4 @@
gradio==4.21.0
sentence-transformers==2.5.1
transformers @ git+https://github.com/huggingface/transformers.git
vecs==0.4.3
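On the generation side, gemma-2b-it is an instruction-tuned checkpoint, yet app.py passes the prompt to it as raw text. A possible refinement, sketched under the assumption that the checkpoint's chat template is preferred, is shown below; the prompt string is a stand-in for the combined_information built in app.py, and this is not the committed code.

# chat_template_sketch.py - alternative generation step using gemma's chat
# template instead of a raw prompt string.
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
model_it = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it")

# Stand-in for the combined_information prompt built in app.py
prompt = "Requête: ...\nRéponds à la requête en te basant sur le contexte suivant :\n...\nRéponse:"
chat = [{"role": "user", "content": prompt}]

# apply_chat_template wraps the prompt in gemma's <start_of_turn> markers
input_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors="pt")
output = model_it.generate(input_ids, max_new_tokens=500)

# Decode only the newly generated tokens and skip special tokens such as <eos>
print(tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True))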