Pclanglais committed
Commit 31559f1
1 Parent(s): 9df01d3

Update app.py

Files changed (1)
  1. app.py +41 -8
app.py CHANGED

@@ -13,8 +13,11 @@ from threading import Thread
 from FlagEmbedding import BGEM3FlagModel
 from sklearn.metrics.pairwise import cosine_similarity
 
+from transformers import AutoModelForSequenceClassification
+
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
+#Importing the embedding model
 embedding_model = BGEM3FlagModel('BAAI/bge-m3',
                                  use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation
 
@@ -22,20 +25,44 @@ embeddings = np.load("embeddings_albert_tchap.npy")
 embeddings_data = pd.read_json("embeddings_albert_tchap.json")
 embeddings_text = embeddings_data["text_with_context"].tolist()
 
-# Define the device
-temperature=0.2
-max_new_tokens=1000
-top_p=0.92
-repetition_penalty=1.7
+#Importing the classifier/router (deberta)
+model_classifier = AutoModelForSequenceClassification.from_pretrained("AgentPublic/chatrag-deberta")
+tokenizer_classifier = AutoTokenizer.from_pretrained("AgentPublic/chatrag-deberta")
 
+#Importing the actual generative LLM (llama-based)
 model_name = "Pclanglais/Tchap"
-
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
 model = model.to('cuda:0')
 
 system_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nTu es Albert, l'agent conversationnel des services publics qui peut décrire des documents de référence ou aider à des tâches de rédaction<|eot_id|>"
 
+#Function to guess whether we use the RAG or not.
+def classification_chatrag(query):
+    encoding = tokenizer_classifier(query, return_tensors="pt")
+    encoding = {k: v.to(model_classifier.device) for k, v in encoding.items()}
+
+    outputs = model_classifier(**encoding)
+
+    logits = outputs.logits
+    logits.shape
+
+    # apply sigmoid + threshold
+    sigmoid = torch.nn.Sigmoid()
+    probs = sigmoid(logits.squeeze().cpu())
+    predictions = np.zeros(probs.shape)
+
+    # Extract the float value from the tensor
+    float_value = round(probs.item()*100)
+
+    if float_value > 50:
+        status = True
+        print("We activate RAG")
+    else:
+        status = False
+        print("We remove RAG")
+    return status
+
 #Vector search over the database
 def vector_search(sentence_query):
 
@@ -71,9 +98,12 @@ class StopOnTokens(StoppingCriteria):
 def predict(message, history):
 
     global source_text
+    global assess_rag
     #For now, we only query the vector database once, at the start.
     if len(history) == 0:
-        source_text = vector_search(message)
+        assess_rag = classification_chatrag(message)
+        if assess_rag:
+            source_text = vector_search(message)
 
     history_transformer_format = history + [[message, ""]]
 
@@ -87,7 +117,10 @@ def predict(message, history):
 
         #Once we target the ongoing post we add the source.
         if id_message == total_message:
-            question = "<|start_header_id|>user<|end_header_id|>\n\n"+ item[0] + "\n\n### Source ###\n" + source_text
+            if assess_rag:
+                question = "<|start_header_id|>user<|end_header_id|>\n\n"+ item[0] + "\n\n### Source ###\n" + source_text
+            else:
+                question = "<|start_header_id|>user<|end_header_id|>\n\n"+ item[0]
         else:
             question = "<|start_header_id|>user<|end_header_id|>\n\n"+ item[0]
         answer = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"+item[1]
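Below is a minimal, self-contained sketch of the routing step this commit adds: the first user message is scored by the AgentPublic/chatrag-deberta classifier, and vector search plus the "### Source ###" block is only attached to the prompt when the router fires. It assumes, as the committed code implies, that the classifier returns a single logit whose sigmoid is the probability that the query needs retrieval; needs_rag and its threshold parameter are illustrative names, not part of app.py.

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Router: a DeBERTa classifier that decides whether a query needs retrieval
tokenizer_classifier = AutoTokenizer.from_pretrained("AgentPublic/chatrag-deberta")
model_classifier = AutoModelForSequenceClassification.from_pretrained("AgentPublic/chatrag-deberta")

def needs_rag(query, threshold=0.5):
    # Tokenize with the classifier's own tokenizer and move tensors to its device
    encoding = tokenizer_classifier(query, return_tensors="pt")
    encoding = {k: v.to(model_classifier.device) for k, v in encoding.items()}
    with torch.no_grad():
        logits = model_classifier(**encoding).logits
    # Assumed single relevance logit -> probability via sigmoid, thresholded at 50%
    prob = torch.sigmoid(logits.squeeze().cpu()).item()
    return prob > threshold

# Hypothetical usage mirroring predict(): only query the vector store when the router fires
query = "Quels documents faut-il pour renouveler une carte d'identité ?"
if needs_rag(query):
    source_text = "..."  # would come from vector_search(query) in app.py
    prompt_suffix = "\n\n### Source ###\n" + source_text
else:
    prompt_suffix = ""

Loading the DeBERTa tokenizer under its own name (tokenizer_classifier) keeps it from colliding with the Llama tokenizer loaded later for Pclanglais/Tchap.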