KonradSzafer committed on
Commit
19fc5ee
·
1 Parent(s): 34ac5d3

llama3 chatQA

Browse files
.github/workflows/tests-integration.yml CHANGED
@@ -30,6 +30,7 @@ jobs:
30
 
31
  - name: Install dependencies
32
  run: |
 
33
  pip install --no-cache-dir -r requirements.txt
34
  cp config/.env.example config/.env
35
  - name: Run unit tests
 
30
 
31
  - name: Install dependencies
32
  run: |
33
+ pip3 install --upgrade pip
34
  pip install --no-cache-dir -r requirements.txt
35
  cp config/.env.example config/.env
36
  - name: Run unit tests
config/prompt_templates/llama3_chat.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ System: This is a chat between a user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions based on the context. The assistant should also indicate when the answer cannot be found in the context.
2
+
3
+ {context}
4
+
5
+ User: {question} Please give a full and complete answer for the question.
6
+
7
+ Assistant:
config/prompt_templates/phi3.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ <|user|>\n Try to be as factual as possible. Answer a question from a given context. CONTEXT: {context} \n QUESTION: {question} <|end|>\n<|assistant|>
qa_engine/config.py CHANGED
@@ -11,7 +11,7 @@ def get_env(env_name: str, default: Any = None, warn: bool = True) -> str:
11
  if default is not None:
12
  if warn:
13
  logger.warning(
14
- f'Environment variable {env_name} not found.' \
15
  f'Using the default value: {default}.'
16
  )
17
  return default
 
11
  if default is not None:
12
  if warn:
13
  logger.warning(
14
+ f'Environment variable {env_name} not found. ' \
15
  f'Using the default value: {default}.'
16
  )
17
  return default
qa_engine/logger.py CHANGED
@@ -2,6 +2,8 @@ import logging
2
 
3
 
4
  logger = logging.getLogger(__name__)
 
 
5
 
6
  def setup_logger() -> None:
7
  """
 
2
 
3
 
4
  logger = logging.getLogger(__name__)
5
+ logging.getLogger('discord').setLevel(logging.ERROR)
6
+ logging.getLogger('discord.gateway').setLevel(logging.ERROR)
7
 
8
  def setup_logger() -> None:
9
  """
qa_engine/qa_engine.py CHANGED
@@ -1,19 +1,11 @@
1
- import os
2
  import re
3
- import json
4
- import requests
5
- import subprocess
6
  from typing import Mapping, Optional, Any
7
 
8
  import torch
9
  import transformers
10
  from transformers import AutoTokenizer, AutoModelForCausalLM
11
  from huggingface_hub import snapshot_download
12
- from urllib.parse import quote
13
- from langchain import PromptTemplate, HuggingFaceHub, LLMChain
14
- from langchain.llms import HuggingFacePipeline
15
- from langchain.llms.base import LLM
16
- from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceHubEmbeddings, HuggingFaceInstructEmbeddings
17
  from langchain.vectorstores import FAISS
18
  from sentence_transformers import CrossEncoder
19
 
@@ -22,41 +14,7 @@ from qa_engine.response import Response
22
  from qa_engine.mocks import MockLocalBinaryModel
23
 
24
 
25
- class LocalBinaryModel(LLM):
26
- model_id: str = None
27
- model_path: str = None
28
- llm: None = None
29
-
30
- def __init__(self, config: Config):
31
- super().__init__()
32
- # pip install llama_cpp_python==0.1.39
33
- from llama_cpp import Llama
34
-
35
- self.model_id = config.question_answering_model_id
36
- self.model_path = f'qa_engine/{self.model_id}'
37
- if not os.path.exists(self.model_path):
38
- raise ValueError(f'{self.model_path} does not exist')
39
- self.llm = Llama(model_path=self.model_path, n_ctx=4096)
40
-
41
- def _call(self, prompt: str, stop: Optional[list[str]] = None) -> str:
42
- output = self.llm(
43
- prompt,
44
- max_tokens=1024,
45
- stop=['Q:'],
46
- echo=False
47
- )
48
- return output['choices'][0]['text']
49
-
50
- @property
51
- def _identifying_params(self) -> Mapping[str, Any]:
52
- return {'name_of_model': self.model_id}
53
-
54
- @property
55
- def _llm_type(self) -> str:
56
- return self.model_id
57
-
58
-
59
- class TransformersPipelineModel(LLM):
60
  model_id: str = None
61
  min_new_tokens: int = None
62
  max_new_tokens: int = None
@@ -64,7 +22,8 @@ class TransformersPipelineModel(LLM):
64
  top_k: int = None
65
  top_p: float = None
66
  do_sample: bool = None
67
- pipeline: str = None
 
68
 
69
  def __init__(self, config: Config):
70
  super().__init__()
@@ -76,35 +35,32 @@ class TransformersPipelineModel(LLM):
76
  self.top_p = config.top_p
77
  self.do_sample = config.do_sample
78
 
79
- tokenizer = AutoTokenizer.from_pretrained(self.model_id)
80
- model = AutoModelForCausalLM.from_pretrained(
81
  self.model_id,
82
- torch_dtype=torch.bfloat16,
83
- trust_remote_code=True,
84
- load_in_8bit=False,
85
- device_map='auto',
86
- resume_download=True,
87
  )
88
- self.pipeline = transformers.pipeline(
89
- 'text-generation',
90
- model=model,
91
- tokenizer=tokenizer,
92
- torch_dtype=torch.bfloat16,
93
- device_map='auto',
94
- eos_token_id=tokenizer.eos_token_id,
95
- pad_token_id=tokenizer.eos_token_id,
 
 
 
 
 
96
  min_new_tokens=self.min_new_tokens,
97
  max_new_tokens=self.max_new_tokens,
98
- temperature=self.temperature,
99
- top_k=self.top_k,
100
- top_p=self.top_p,
101
- do_sample=self.do_sample,
102
  )
103
-
104
- def _call(self, prompt: str, stop: Optional[list[str]] = None) -> str:
105
- output_text = self.pipeline(prompt)[0]['generated_text']
106
- output_text = output_text.replace(prompt+'\n', '')
107
- return output_text
108
 
109
  @property
110
  def _identifying_params(self) -> Mapping[str, Any]:
@@ -115,39 +71,6 @@ class TransformersPipelineModel(LLM):
115
  return self.model_id
116
 
117
 
118
- class APIServedModel(LLM):
119
- model_url: str = None
120
- debug: bool = None
121
-
122
- def __init__(self, model_url: str, debug: bool = False):
123
- super().__init__()
124
- if model_url[-1] == '/':
125
- raise ValueError('URL should not end with a slash - "/"')
126
- self.model_url = model_url
127
- self.debug = debug
128
-
129
- def _call(self, prompt: str, stop: Optional[list[str]] = None) -> str:
130
- prompt_encoded = quote(prompt, safe='')
131
- url = f'{self.model_url}/?prompt={prompt_encoded}'
132
- if self.debug:
133
- logger.info(f'URL: {url}')
134
- try:
135
- response = requests.get(url, timeout=1200, verify=False)
136
- response.raise_for_status()
137
- return json.loads(response.content)['output_text']
138
- except Exception as err:
139
- logger.error(f'Error: {err}')
140
- return f'Error: {err}'
141
-
142
- @property
143
- def _identifying_params(self) -> Mapping[str, Any]:
144
- return {'name_of_model': f'model url: {self.model_url}'}
145
-
146
- @property
147
- def _llm_type(self) -> str:
148
- return 'api_model'
149
-
150
-
151
  class QAEngine():
152
  """
153
  QAEngine class, used for generating answers to questions.
@@ -163,16 +86,10 @@ class QAEngine():
163
  self.num_relevant_docs=config.num_relevant_docs
164
  self.add_sources_to_response=config.add_sources_to_response
165
  self.use_messages_for_context=config.use_messages_in_context
166
- self.debug=config.debug
167
-
168
  self.first_stage_docs: int = 50
169
 
170
- prompt = PromptTemplate(
171
- template=self.prompt_template,
172
- input_variables=['question', 'context']
173
- )
174
  self.llm_model = self._get_model()
175
- self.llm_chain = LLMChain(prompt=prompt, llm=self.llm_model)
176
 
177
  if self.use_docs_for_context:
178
  logger.info(f'Downloading {self.index_repo_id}')
@@ -196,29 +113,39 @@ class QAEngine():
196
 
197
 
198
  def _get_model(self):
199
- if 'local_models/' in self.question_answering_model_id:
200
- logger.info('using local binary model')
201
- return LocalBinaryModel(self.config)
202
- elif 'api_models/' in self.question_answering_model_id:
203
- logger.info('using api served model')
204
- return APIServedModel(
205
- model_url=self.question_answering_model_id.replace('api_models/', ''),
206
- debug=self.debug
207
- )
208
- elif self.question_answering_model_id == 'mock':
209
- logger.info('using mock model')
210
  return MockLocalBinaryModel()
211
  else:
212
  logger.info('using transformers pipeline model')
213
- return TransformersPipelineModel(self.config)
214
-
215
-
216
  @staticmethod
217
- def _preprocess_question(question: str) -> str:
218
  if '?' not in question:
219
  question += '?'
220
- return question
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
 
 
222
 
223
  @staticmethod
224
  def _postprocess_answer(answer: str) -> str:
@@ -289,8 +216,8 @@ class QAEngine():
289
  response.set_sources(sources=[str(m['source']) for m in metadata])
290
 
291
  logger.info('Running LLM chain')
292
- question_processed = QAEngine._preprocess_question(question)
293
- answer = self.llm_chain.run(question=question_processed, context=context)
294
  answer_postprocessed = QAEngine._postprocess_answer(answer)
295
  response.set_answer(answer_postprocessed)
296
  logger.info('Received answer')
 
 
1
  import re
 
 
 
2
  from typing import Mapping, Optional, Any
3
 
4
  import torch
5
  import transformers
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
  from huggingface_hub import snapshot_download
8
+ from langchain.embeddings import HuggingFaceInstructEmbeddings
 
 
 
 
9
  from langchain.vectorstores import FAISS
10
  from sentence_transformers import CrossEncoder
11
 
 
14
  from qa_engine.mocks import MockLocalBinaryModel
15
 
16
 
17
+ class HuggingFaceModel:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  model_id: str = None
19
  min_new_tokens: int = None
20
  max_new_tokens: int = None
 
22
  top_k: int = None
23
  top_p: float = None
24
  do_sample: bool = None
25
+ tokenizer: transformers.PreTrainedTokenizer = None
26
+ model: transformers.PreTrainedModel = None
27
 
28
  def __init__(self, config: Config):
29
  super().__init__()
 
35
  self.top_p = config.top_p
36
  self.do_sample = config.do_sample
37
 
38
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
39
+ self.model = AutoModelForCausalLM.from_pretrained(
40
  self.model_id,
41
+ torch_dtype=torch.float16,
42
+ device_map="auto"
 
 
 
43
  )
44
+
45
+ def _call(self, prompt: str, stop: Optional[list[str]] = None) -> str:
46
+ tokenized_prompt = self.tokenizer(
47
+ self.tokenizer.bos_token + prompt,
48
+ return_tensors="pt"
49
+ ).to(self.model.device)
50
+ terminators = [
51
+ self.tokenizer.eos_token_id,
52
+ self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
53
+ ]
54
+ outputs = self.model.generate(
55
+ input_ids=tokenized_prompt.input_ids,
56
+ attention_mask=tokenized_prompt.attention_mask,
57
  min_new_tokens=self.min_new_tokens,
58
  max_new_tokens=self.max_new_tokens,
59
+ eos_token_id=terminators
 
 
 
60
  )
61
+ response = outputs[0][tokenized_prompt.input_ids.shape[-1]:]
62
+ decoded_response = self.tokenizer.decode(response, skip_special_tokens=True)
63
+ return decoded_response
 
 
64
 
65
  @property
66
  def _identifying_params(self) -> Mapping[str, Any]:
 
71
  return self.model_id
72
 
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  class QAEngine():
75
  """
76
  QAEngine class, used for generating answers to questions.
 
86
  self.num_relevant_docs=config.num_relevant_docs
87
  self.add_sources_to_response=config.add_sources_to_response
88
  self.use_messages_for_context=config.use_messages_in_context
89
+ self.debug=config.debug
 
90
  self.first_stage_docs: int = 50
91
 
 
 
 
 
92
  self.llm_model = self._get_model()
 
93
 
94
  if self.use_docs_for_context:
95
  logger.info(f'Downloading {self.index_repo_id}')
 
113
 
114
 
115
  def _get_model(self):
116
+ if self.question_answering_model_id == 'mock':
117
+ logger.warn('using mock model')
 
 
 
 
 
 
 
 
 
118
  return MockLocalBinaryModel()
119
  else:
120
  logger.info('using transformers pipeline model')
121
+ return HuggingFaceModel(self.config)
122
+
 
123
  @staticmethod
124
+ def _preprocess_input(question: str, context: str) -> str:
125
  if '?' not in question:
126
  question += '?'
127
+
128
+ # llama3 chatQA specific
129
+ messages = [
130
+ {"role": "user", "content": question}
131
+ ]
132
+
133
+ system = "System: This is a chat between a user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions based on the context. The assistant should also indicate when the answer cannot be found in the context."
134
+ instruction = "Please give a full and complete answer for the question."
135
+
136
+ for item in messages:
137
+ if item['role'] == "user":
138
+ ## only apply this instruction for the first user turn
139
+ item['content'] = instruction + " " + item['content']
140
+ break
141
+
142
+ conversation = '\n\n'.join([
143
+ "User: " + item["content"] if item["role"] == "user" else
144
+ "Assistant: " + item["content"] for item in messages
145
+ ]) + "\n\nAssistant:"
146
 
147
+ inputs = system + "\n\n" + context + "\n\n" + conversation
148
+ return inputs
149
 
150
  @staticmethod
151
  def _postprocess_answer(answer: str) -> str:
 
216
  response.set_sources(sources=[str(m['source']) for m in metadata])
217
 
218
  logger.info('Running LLM chain')
219
+ inputs = QAEngine._preprocess_input(question, context)
220
+ answer = self.llm_model._call(inputs)
221
  answer_postprocessed = QAEngine._postprocess_answer(answer)
222
  response.set_answer(answer_postprocessed)
223
  logger.info('Received answer')
requirements.txt CHANGED
@@ -1,5 +1,4 @@
1
  torch
2
- torchvision
3
  transformers
4
  accelerate
5
  einops
 
1
  torch
 
2
  transformers
3
  accelerate
4
  einops