akfung committed on
Commit
dff38c9
·
1 Parent(s): afc30f3

hf formatting and prompting

Browse files
Files changed (3) hide show
  1. setup.py +1 -1
  2. src/config.py +2 -1
  3. src/model.py +5 -5
setup.py CHANGED
@@ -1,4 +1,4 @@
1
  from sentence_transformers import SentenceTransformer
2
 
3
  embedding_model = SentenceTransformer('multi-qa-mpnet-base-dot-v1')
4
- embedding_model.save('embedding_model/')
 
1
  from sentence_transformers import SentenceTransformer
2
 
3
  embedding_model = SentenceTransformer('multi-qa-mpnet-base-dot-v1')
4
+ embedding_model.save('/embedding_model/')
src/config.py CHANGED
@@ -10,10 +10,12 @@ if env == 'local':
10
  model_path = '../merged_llama2/'
11
  tokenizer_path = '../Llama2/7B/tokenizer.model'
12
  load_dotenv()
 
13
 
14
  elif env == 'spaces':
15
  model_path = 'akfung/llama_supreme'
16
  tokenizer_path = 'akfung/llama_supreme'
 
17
 
18
  elif env == 'gcp':
19
  model_path = 'model/'
@@ -47,7 +49,6 @@ headers = {
47
  "Content-Type": "application/json"
48
  }
49
 
50
- embedding_path = os.environ.get('EMBEDDING_PATH')
51
  streaming_url = os.environ.get('STREAMING_URL')
52
  job_url = os.environ.get('JOB_URL')
53
 
 
10
  model_path = '../merged_llama2/'
11
  tokenizer_path = '../Llama2/7B/tokenizer.model'
12
  load_dotenv()
13
+ embedding_path = "embedding_model/"
14
 
15
  elif env == 'spaces':
16
  model_path = 'akfung/llama_supreme'
17
  tokenizer_path = 'akfung/llama_supreme'
18
+ embedding_path = "/embedding_model/"
19
 
20
  elif env == 'gcp':
21
  model_path = 'model/'
 
49
  "Content-Type": "application/json"
50
  }
51
 
 
52
  streaming_url = os.environ.get('STREAMING_URL')
53
  job_url = os.environ.get('JOB_URL')
54
 
src/model.py CHANGED
@@ -4,7 +4,7 @@ import time
4
  # from google.cloud import storage
5
  from sentence_transformers import SentenceTransformer
6
 
7
- from .config import max_new_tokens, streaming_url, job_url, default_payload, headers
8
  from .db.db_utilities import query_db
9
 
10
  class Model:
@@ -15,7 +15,7 @@ class Model:
15
  max_new_tokens:int=max_new_tokens):
16
  self.max_new_tokens = max_new_tokens
17
  # self.embedding_model = SentenceTransformer('multi-qa-mpnet-base-dot-v1')
18
- self.embedding_model = SentenceTransformer("embedding_model/")
19
 
20
 
21
  def inference(self, query:str, table:str):
@@ -35,11 +35,11 @@ class Model:
35
  if len(matches) > 0:
36
  match = '"""' + matches[0][0] + '"""'
37
 
38
- context = "You are the United States Supreme Court. Use the following historical opinion delimited by triple quotes to give your ruling on a court case description. Historical opinion: " + match
39
  else:
40
  context = 'You are the United States Supreme Court. Give your ruling on a court case description.'
41
 
42
- return context + " Answer in less than 400 words. Do not introduce yourself"
43
 
44
  def query_model(self, query:str, table:str, default_payload:dict=default_payload, timeout:int=60, **kwargs) -> str:
45
  """Query the model api on runpod. Runs for 60s by default. Generator response until job is complete"""
@@ -57,7 +57,7 @@ class Model:
57
  "content": query,
58
  }
59
  ]
60
- print(augmented_prompt_template)
61
  default_payload["input"]["prompt"] = augmented_prompt_template
62
  job_id = requests.post(job_url, json=default_payload, headers=headers).json()['id']
63
  for i in range(timeout):
 
4
  # from google.cloud import storage
5
  from sentence_transformers import SentenceTransformer
6
 
7
+ from .config import max_new_tokens, streaming_url, job_url, default_payload, headers, embedding_path
8
  from .db.db_utilities import query_db
9
 
10
  class Model:
 
15
  max_new_tokens:int=max_new_tokens):
16
  self.max_new_tokens = max_new_tokens
17
  # self.embedding_model = SentenceTransformer('multi-qa-mpnet-base-dot-v1')
18
+ self.embedding_model = SentenceTransformer(embedding_path)
19
 
20
 
21
  def inference(self, query:str, table:str):
 
35
  if len(matches) > 0:
36
  match = '"""' + matches[0][0] + '"""'
37
 
38
+ context = "You are the United States Supreme Court. Use the following historical opinion to give your ruling on a court case description. Historical opinion: " + match
39
  else:
40
  context = 'You are the United States Supreme Court. Give your ruling on a court case description.'
41
 
42
+ return context + " Answer in less than 400 words in the format Opinion: <opinion> "
43
 
44
  def query_model(self, query:str, table:str, default_payload:dict=default_payload, timeout:int=60, **kwargs) -> str:
45
  """Query the model api on runpod. Runs for 60s by default. Generator response until job is complete"""
 
57
  "content": query,
58
  }
59
  ]
60
+
61
  default_payload["input"]["prompt"] = augmented_prompt_template
62
  job_id = requests.post(job_url, json=default_payload, headers=headers).json()['id']
63
  for i in range(timeout):