adrien.aribaut-gaudin committed on
Commit
4de2d8b
1 Parent(s): 2ca433f

fix: change control.py so that it doesn't execute unnecessary functions, and update llm.py with better prompts and better responses

Browse files
Files changed (2) hide show
  1. src/control/control.py +5 -2
  2. src/tools/llm.py +58 -38
src/control/control.py CHANGED
@@ -11,13 +11,16 @@ class Chatbot:
11
  def get_response(self, query, histo):
12
  histo_conversation, histo_queries = self._get_histo(histo)
13
  langage_of_query = self.llm.detect_language(query).lower()
14
- queries = self.llm.translate(text=histo_queries)
 
 
 
15
  block_sources = self.retriever.similarity_search(query=queries)
16
  block_sources = self._select_best_sources(block_sources)
17
  sources_contents = [s.content for s in block_sources]
18
  context = '\n'.join(sources_contents)
19
  answer = self.llm.generate_paragraph(query=queries, histo=histo_conversation, context=context, language=langage_of_query)
20
- answer = self.llm.generate_answer(answer=answer, query=query, histo=histo_conversation, context=context,language=langage_of_query)
21
  # print(answer.split('bot:')[1].strip())
22
  # print("*************")
23
  # answer = self._clean_answer(answer)
 
11
  def get_response(self, query, histo):
12
  histo_conversation, histo_queries = self._get_histo(histo)
13
  langage_of_query = self.llm.detect_language(query).lower()
14
+ if langage_of_query != "en":
15
+ queries = self.llm.translate(text=histo_queries)
16
+ else:
17
+ queries = histo_queries
18
  block_sources = self.retriever.similarity_search(query=queries)
19
  block_sources = self._select_best_sources(block_sources)
20
  sources_contents = [s.content for s in block_sources]
21
  context = '\n'.join(sources_contents)
22
  answer = self.llm.generate_paragraph(query=queries, histo=histo_conversation, context=context, language=langage_of_query)
23
+ # answer = self.llm.generate_answer(answer=answer, query=query, histo=histo_conversation, context=context,language=langage_of_query)
24
  # print(answer.split('bot:')[1].strip())
25
  # print("*************")
26
  # answer = self._clean_answer(answer)
src/tools/llm.py CHANGED
@@ -12,13 +12,14 @@ class LlmAgent:
12
  device_map="cuda:0",
13
  trust_remote_code=False, #A CHANGER SELON LES MODELES, POUR CELUI DE LAMA2 CA MARCHE (celui par default)
14
  revision="main",token=token)
15
- self.pipe = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer,torch_dtype=torch.float16)
16
 
17
  def generate_paragraph(self, query: str, context: {}, histo: [(str, str)], language='fr') -> str:
18
  torch.cuda.empty_cache()
19
- locallm = HuggingFacePipeline(pipeline=self.pipe)
20
  """generates the answer"""
21
  template = (f'''[INST] <<SYS>>"
 
 
22
  "You are a conversation bot designed to answer to the query from users"
23
  "Your answer is based on the context delimited by triple backticks: "
24
  "\\n ``` {context} ```\\n"
@@ -26,43 +27,48 @@ class LlmAgent:
26
  "delimited by triple backticks: "
27
  "\\n ``` {histo} ```\\n"
28
  "Your response shall be in {language} and shall be concise"
 
 
29
  "\\n <</SYS>>"
30
  "\\n {query}[/INST]''')
31
- prompt = PromptTemplate(input_variables=[], template=template)
32
- llm_chain = LLMChain(prompt=prompt,llm=locallm)
33
- p = llm_chain.predict()
34
  # print("****************")
35
  # print(template)
36
  # print("----")
37
- # print(p)
38
- return p
 
39
 
40
  def translate(self, text: str, language="en") -> str:
41
  torch.cuda.empty_cache()
42
- locallm = HuggingFacePipeline(pipeline=self.pipe)
43
  """translates"""
44
 
45
  # languages = "`French to English" if language == "en" else "English to French"
46
 
47
- tempate = (f'''[INST] <<SYS>>
48
- Your task consists in translating in English\\n"
 
 
49
  the following text:
50
  <</SYS>>
51
  {text}[/INST]'''
52
  )
53
 
54
- prompt = PromptTemplate(input_variables=[], template=tempate)
55
- llm_chain = LLMChain(prompt=prompt,llm=locallm,verbose=True)
56
- p = llm_chain.predict()
57
- return p
 
 
58
 
59
  def generate_answer(self, query: str, answer: str, histo: str, context: str,language : str) -> str:
60
  torch.cuda.empty_cache()
61
  """provides the final answer in {language} based on the initial query and the answer in english"""
62
  def _cut_unfinished_sentence(s: str):
63
  return '.'.join(s.split('.')[:-1])
64
- locallm = HuggingFacePipeline(pipeline=self.pipe)
65
  template = (f'''[INST] <<SYS>>
 
 
66
  Your task consists in translating the answer in {language}, if its not already the case, to the query "
67
  delimited by triple backticks: ```{query}``` \\n"
68
  \\n You don't add new content to the answer but: "
@@ -70,21 +76,25 @@ class LlmAgent:
70
  ```{context}```"
71
  \\n 2 You are consistent and avoid redundancies with the rest of the initial"
72
  conversation delimited by triple backticks: ```{histo}```"
 
 
73
  You are given the answer in {language}:
74
  <</SYS>>
75
  {answer}[/INST]'''
76
  )
77
- prompt = PromptTemplate(input_variables=[], template=template)
78
- llm_chain = LLMChain(prompt=prompt,llm=locallm,verbose=True)
79
- p = llm_chain.predict()
80
- # p = _cut_unfinished_sentence(p)
81
- return p
 
 
82
 
83
 
84
  def transform_parahraph_into_question(self, prompt : str, title_doc : str = '',title_para : str = '') -> str:
85
  torch.cuda.empty_cache()
86
- self.tokenizer.pad_token = self.tokenizer.eos_token
87
- max_tokens = 45
88
 
89
  prompt_template=f'''[INST] <<SYS>>
90
  You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
@@ -92,35 +102,45 @@ class LlmAgent:
92
  Your job is to create a question about a paragraph of a document untitled "{title_doc}".
93
  The paragraph title is "{title_para}".
94
  If you see that the question that you are creating will not respect {max_tokens} tokens, find a way to make it shorter.
95
- If you see that the document paragraph seems to be code flattened, try to analyze it and create a question about it.
96
- If you see that the paragraph is a table, try to create a question about it.
97
  If you can't create a question about the paragraph, just rephrase {title_para} so that it becomes a question.
98
- Your response shall only contains one question, shall be concise and shall respect the following format:
99
- "Question: <question>"
 
100
  The paragraph you need to create a question about is the following :
101
  <</SYS>>
102
  {prompt}[/INST]
103
 
104
  '''
105
- input_ids = self.tokenizer(prompt_template, return_tensors='pt').input_ids.cuda()
106
- output = self.model.generate(inputs=input_ids, temperature=0.7, do_sample=True, top_p=0.95, top_k=40, max_new_tokens=max_tokens,num_return_sequences=1)
107
-
108
- res1 = self.tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)
109
- print(res1)
110
- print("-"*len(res1))
111
- return res1
 
 
 
 
 
 
112
 
113
  def detect_language(self, text: str) -> str:
114
  torch.cuda.empty_cache()
115
  """detects the language"""
116
- locallm = HuggingFacePipeline(pipeline=self.pipe)
117
  template = (f'''[INST] <<SYS>>
 
 
118
  Your task consists in detecting the language of the user query"
119
  Your answer shall be the two letters code of the language"
 
 
120
  \\n <</SYS>>"
121
  \\n {text}[/INST]'''
122
  )
123
- prompt = PromptTemplate(input_variables=[], template=template)
124
- llm_chain = LLMChain(prompt=prompt,llm=locallm,verbose=True)
125
- p = llm_chain.predict()
126
- return p
 
 
 
12
  device_map="cuda:0",
13
  trust_remote_code=False, #A CHANGER SELON LES MODELES, POUR CELUI DE LAMA2 CA MARCHE (celui par default)
14
  revision="main",token=token)
15
+ self.pipe = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer,torch_dtype=torch.float16,max_new_tokens=256,repetition_penalty=1.1,top_k=40,top_p=0.95,temperature=0.7,do_sample=True,return_full_text=False)
16
 
17
  def generate_paragraph(self, query: str, context: {}, histo: [(str, str)], language='fr') -> str:
18
  torch.cuda.empty_cache()
 
19
  """generates the answer"""
20
  template = (f'''[INST] <<SYS>>"
21
+ You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
22
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
23
  "You are a conversation bot designed to answer to the query from users"
24
  "Your answer is based on the context delimited by triple backticks: "
25
  "\\n ``` {context} ```\\n"
 
27
  "delimited by triple backticks: "
28
  "\\n ``` {histo} ```\\n"
29
  "Your response shall be in {language} and shall be concise"
30
+ You should respect the following format: "
31
+ <response>"
32
  "\\n <</SYS>>"
33
  "\\n {query}[/INST]''')
34
+ pipe = self.pipe(template)
 
 
35
  # print("****************")
36
  # print(template)
37
  # print("----")
38
+ res = pipe[0]["generated_text"]
39
+ print(res)
40
+ return res
41
 
42
  def translate(self, text: str, language="en") -> str:
43
  torch.cuda.empty_cache()
 
44
  """translates"""
45
 
46
  # languages = "`French to English" if language == "en" else "English to French"
47
 
48
+ template = (f'''[INST] <<SYS>>
49
+ You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
50
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
51
+ Your task consists in translating in English\\n"
52
  the following text:
53
  <</SYS>>
54
  {text}[/INST]'''
55
  )
56
 
57
+ pipe = self.pipe(template)
58
+ # print("****************")
59
+ # print(template)
60
+ # print("----")
61
+ res = pipe[0]["generated_text"]
62
+ return res
63
 
64
  def generate_answer(self, query: str, answer: str, histo: str, context: str,language : str) -> str:
65
  torch.cuda.empty_cache()
66
  """provides the final answer in {language} based on the initial query and the answer in english"""
67
  def _cut_unfinished_sentence(s: str):
68
  return '.'.join(s.split('.')[:-1])
 
69
  template = (f'''[INST] <<SYS>>
70
+ You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
71
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
72
  Your task consists in translating the answer in {language}, if its not already the case, to the query "
73
  delimited by triple backticks: ```{query}``` \\n"
74
  \\n You don't add new content to the answer but: "
 
76
  ```{context}```"
77
  \\n 2 You are consistent and avoid redundancies with the rest of the initial"
78
  conversation delimited by triple backticks: ```{histo}```"
79
+ Your response shall respect the following format: "
80
+ <response>"
81
  You are given the answer in {language}:
82
  <</SYS>>
83
  {answer}[/INST]'''
84
  )
85
+ pipe = self.pipe(template)
86
+ # print("****************")
87
+ # print(template)
88
+ # print("----")
89
+ res = pipe[0]["generated_text"]
90
+ print(res)
91
+ return res
92
 
93
 
94
  def transform_parahraph_into_question(self, prompt : str, title_doc : str = '',title_para : str = '') -> str:
95
  torch.cuda.empty_cache()
96
+ max_tokens = 80
97
+ pipeline_modified = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer,torch_dtype=torch.float16,max_new_tokens=max_tokens,repetition_penalty=1.1,top_k=40,top_p=0.95,temperature=0.7,do_sample=True,return_full_text=False)
98
 
99
  prompt_template=f'''[INST] <<SYS>>
100
  You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
 
102
  Your job is to create a question about a paragraph of a document untitled "{title_doc}".
103
  The paragraph title is "{title_para}".
104
  If you see that the question that you are creating will not respect {max_tokens} tokens, find a way to make it shorter.
 
 
105
  If you can't create a question about the paragraph, just rephrase {title_para} so that it becomes a question.
106
+ Your response shall contains two questions, shall be concise and shall respect the following format:
107
+ "Question: <question1>\\nQuestion: <question2>\\n"
108
+ You should not answer to the question, just create it.
109
  The paragraph you need to create a question about is the following :
110
  <</SYS>>
111
  {prompt}[/INST]
112
 
113
  '''
114
+ pipe = pipeline_modified(prompt_template)
115
+ # print("****************")
116
+ # print(template)
117
+ # print("----")
118
+ #filter the answer to only keep the question
119
+ res = pipe[0]["generated_text"]
120
+ # res = res.split("Question: ")
121
+ # res1 = res[1]
122
+ # res2 = res[2]
123
+ # print(res1)
124
+ # print(res2)
125
+ print(res)
126
+ return res
127
 
128
  def detect_language(self, text: str) -> str:
129
  torch.cuda.empty_cache()
130
  """detects the language"""
 
131
  template = (f'''[INST] <<SYS>>
132
+ You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
133
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
134
  Your task consists in detecting the language of the user query"
135
  Your answer shall be the two letters code of the language"
136
+ and should respect the following format: "
137
+ <code>"
138
  \\n <</SYS>>"
139
  \\n {text}[/INST]'''
140
  )
141
+ pipe = self.pipe(template)
142
+ # print("****************")
143
+ # print(template)
144
+ # print("----")
145
+ res = pipe[0]["generated_text"]
146
+ return res