akfung commited on
Commit
afc30f3
·
1 Parent(s): b0be57c

fixed chat template issues

Browse files
Files changed (3) hide show
  1. setup.py +1 -1
  2. src/config.py +3 -2
  3. src/model.py +18 -7
setup.py CHANGED
@@ -1,4 +1,4 @@
1
  from sentence_transformers import SentenceTransformer
2
 
3
  embedding_model = SentenceTransformer('multi-qa-mpnet-base-dot-v1')
4
- embedding_model.save('/embedding_model/')
 
1
  from sentence_transformers import SentenceTransformer
2
 
3
  embedding_model = SentenceTransformer('multi-qa-mpnet-base-dot-v1')
4
+ embedding_model.save('embedding_model/')
src/config.py CHANGED
@@ -48,11 +48,12 @@ headers = {
48
  }
49
 
50
  embedding_path = os.environ.get('EMBEDDING_PATH')
51
- streaming_url = "https://api.runpod.ai/v2/o4tke61qpopsz0/stream/"
52
- job_url = "https://api.runpod.ai/v2/o4tke61qpopsz0/run"
53
 
54
  default_payload = { "input": {
55
  "prompt": "Who is the president of the United States?",
 
56
  "sampling_params": {
57
  "max_tokens": os.environ.get('max_new_tokens', 400),
58
  "n": 1,
 
48
  }
49
 
50
  embedding_path = os.environ.get('EMBEDDING_PATH')
51
+ streaming_url = os.environ.get('STREAMING_URL')
52
+ job_url = os.environ.get('JOB_URL')
53
 
54
  default_payload = { "input": {
55
  "prompt": "Who is the president of the United States?",
56
+ "apply_chat_template": True,
57
  "sampling_params": {
58
  "max_tokens": os.environ.get('max_new_tokens', 400),
59
  "n": 1,
src/model.py CHANGED
@@ -15,7 +15,7 @@ class Model:
15
  max_new_tokens:int=max_new_tokens):
16
  self.max_new_tokens = max_new_tokens
17
  # self.embedding_model = SentenceTransformer('multi-qa-mpnet-base-dot-v1')
18
- self.embedding_model = SentenceTransformer("/embedding_model/")
19
 
20
 
21
  def inference(self, query:str, table:str):
@@ -35,19 +35,30 @@ class Model:
35
  if len(matches) > 0:
36
  match = '"""' + matches[0][0] + '"""'
37
 
38
- context = "Use the following historical opinion delimited by tripple quotes to give your ruling on a court case description. " + match + " Description: "
39
  else:
40
- context = 'Give your ruling on a court case description. Description:'
41
 
42
- return context + query + " Answer in less than 400 words and without a self introduction."
43
 
44
  def query_model(self, query:str, table:str, default_payload:dict=default_payload, timeout:int=60, **kwargs) -> str:
45
  """Query the model api on runpod. Runs for 60s by default. Generator response until job is complete"""
46
 
47
- augmented_prompt = self.get_context(query=query, table=table)
48
  for k,v in kwargs:
49
  default_payload['input']['sampling_params'][k] = v
50
- default_payload["input"]["prompt"] = augmented_prompt
 
 
 
 
 
 
 
 
 
 
 
51
  job_id = requests.post(job_url, json=default_payload, headers=headers).json()['id']
52
  for i in range(timeout):
53
  time.sleep(1)
@@ -77,4 +88,4 @@ class Model:
77
  # for object_name in model_file_paths:
78
  # blob = bucket.blob(object_name)
79
  # blob.download_to_filename(object_name)
80
-
 
15
  max_new_tokens:int=max_new_tokens):
16
  self.max_new_tokens = max_new_tokens
17
  # self.embedding_model = SentenceTransformer('multi-qa-mpnet-base-dot-v1')
18
+ self.embedding_model = SentenceTransformer("embedding_model/")
19
 
20
 
21
  def inference(self, query:str, table:str):
 
35
  if len(matches) > 0:
36
  match = '"""' + matches[0][0] + '"""'
37
 
38
+ context = "You are the United States Supreme Court. Use the following historical opinion delimited by triple quotes to give your ruling on a court case description. Historical opinion: " + match
39
  else:
40
+ context = 'You are the United States Supreme Court. Give your ruling on a court case description.'
41
 
42
+ return context + " Answer in less than 400 words. Do not introduce yourself"
43
 
44
  def query_model(self, query:str, table:str, default_payload:dict=default_payload, timeout:int=60, **kwargs) -> str:
45
  """Query the model api on runpod. Runs for 60s by default. Generator response until job is complete"""
46
 
47
+ context = self.get_context(query=query, table=table)
48
  for k,v in kwargs:
49
  default_payload['input']['sampling_params'][k] = v
50
+ augmented_prompt_template = [
51
+ {
52
+ "role": "system",
53
+ "content": context,
54
+ },
55
+ {
56
+ "role": "user",
57
+ "content": query,
58
+ }
59
+ ]
60
+ print(augmented_prompt_template)
61
+ default_payload["input"]["prompt"] = augmented_prompt_template
62
  job_id = requests.post(job_url, json=default_payload, headers=headers).json()['id']
63
  for i in range(timeout):
64
  time.sleep(1)
 
88
  # for object_name in model_file_paths:
89
  # blob = bucket.blob(object_name)
90
  # blob.download_to_filename(object_name)
91
+