Update app.py
Pclanglais committed • Commit 31559f1 • Parent(s): 9df01d3

app.py CHANGED
@@ -13,8 +13,11 @@ from threading import Thread
 from FlagEmbedding import BGEM3FlagModel
 from sklearn.metrics.pairwise import cosine_similarity
 
+from transformers import AutoModelForSequenceClassification
+
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
+#Importing the embedding model
 embedding_model = BGEM3FlagModel('BAAI/bge-m3',
                                  use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation
 
@@ -22,20 +25,44 @@ embeddings = np.load("embeddings_albert_tchap.npy")
 embeddings_data = pd.read_json("embeddings_albert_tchap.json")
 embeddings_text = embeddings_data["text_with_context"].tolist()
 
-#
-
-
-top_p=0.92
-repetition_penalty=1.7
+#Importing the classifier/router (deberta)
+classifier_model = AutoModelForSequenceClassification.from_pretrained("AgentPublic/chatrag-deberta")
+tokenizer = AutoTokenizer.from_pretrained("AgentPublic/chatrag-deberta")
 
+#Importing the actual generative LLM (llama-based)
 model_name = "Pclanglais/Tchap"
-
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
 model = model.to('cuda:0')
 
 system_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nTu es Albert, l'agent conversationnel des services publics qui peut décrire des documents de référence ou aider à des tâches de rédaction<|eot_id|>"
 
+#Function to guess whether we use the RAG or not.
+def classification_chatrag(query):
+    encoding = tokenizer(query, return_tensors="pt")
+    encoding = {k: v.to(model_classifier.device) for k,v in encoding.items()}
+
+    outputs = model_classifier(**encoding)
+
+    logits = outputs.logits
+    logits.shape
+
+    # apply sigmoid + threshold
+    sigmoid = torch.nn.Sigmoid()
+    probs = sigmoid(logits.squeeze().cpu())
+    predictions = np.zeros(probs.shape)
+
+    # Extract the float value from the tensor
+    float_value = round(probs.item()*100)
+
+    if float_value > 50:
+        status = True
+        print("We activate RAG")
+    else:
+        status = False
+        print("We remove RAG")
+    return status
+
 #Vector search over the database
 def vector_search(sentence_query):
 
@@ -71,9 +98,12 @@ class StopOnTokens(StoppingCriteria):
 def predict(message, history):
 
     global source_text
+    global assess_rag
     #For now, we only query the vector database once, at the start.
     if len(history) == 0:
-        source_text = vector_search(message)
+        assess_rag = classification_chatrag(message)
+        if assess_rag:
+            source_text = vector_search(message)
 
     history_transformer_format = history + [[message, ""]]
 
@@ -87,7 +117,10 @@ def predict(message, history):
 
         #Once we target the ongoing post we add the source.
        if id_message == total_message:
-            question = "<|start_header_id|>user<|end_header_id|>\n\n"+ item[0] + "\n\n### Source ###\n" + source_text
+            if assess_rag:
+                question = "<|start_header_id|>user<|end_header_id|>\n\n"+ item[0] + "\n\n### Source ###\n" + source_text
+            else:
+                question = "<|start_header_id|>user<|end_header_id|>\n\n"+ item[0]
         else:
             question = "<|start_header_id|>user<|end_header_id|>\n\n"+ item[0]
             answer = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"+item[1]
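The change amounts to a two-model setup: a small DeBERTa classifier routes the incoming query, and only queries it flags as needing retrieval trigger the embedding search before generation. Note that the committed function refers to `model_classifier` while the load step assigns `classifier_model`, and the DeBERTa tokenizer is re-assigned to the Tchap tokenizer a few lines later. The sketch below is a minimal, self-contained illustration of the router with those names kept separate, not the exact app code: `classifier_name` and `classifier_tokenizer` are illustrative names, and the single-logit sigmoid threshold simply mirrors what the commit does.

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Illustrative names; the commit reuses `tokenizer` for both models.
classifier_name = "AgentPublic/chatrag-deberta"
classifier_tokenizer = AutoTokenizer.from_pretrained(classifier_name)
classifier_model = AutoModelForSequenceClassification.from_pretrained(classifier_name)

def classification_chatrag(query):
    # Tokenize the query and move the tensors to the classifier's device.
    encoding = classifier_tokenizer(query, return_tensors="pt")
    encoding = {k: v.to(classifier_model.device) for k, v in encoding.items()}

    # Single forward pass; no gradients are needed at inference time.
    with torch.no_grad():
        logits = classifier_model(**encoding).logits

    # Sigmoid + 50% threshold on the single logit, as in the committed version.
    prob = torch.sigmoid(logits.squeeze().cpu()).item()
    return round(prob * 100) > 50

In predict(), this flag is computed once on the first turn; only when it is True does the app call vector_search(message) and append the retrieved passage to the first user message under a "### Source ###" header.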