Spaces: Sleeping
hf formatting and prompting
Browse files
- setup.py +1 -1
- src/config.py +2 -1
- src/model.py +5 -5
setup.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
from sentence_transformers import SentenceTransformer
|
2 |
|
3 |
embedding_model = SentenceTransformer('multi-qa-mpnet-base-dot-v1')
|
4 |
-
embedding_model.save('embedding_model/')
|
|
|
1 |
from sentence_transformers import SentenceTransformer
|
2 |
|
3 |
embedding_model = SentenceTransformer('multi-qa-mpnet-base-dot-v1')
|
4 |
+
embedding_model.save('/embedding_model/')
|
src/config.py
CHANGED
@@ -10,10 +10,12 @@ if env == 'local':
|
|
10 |
model_path = '../merged_llama2/'
|
11 |
tokenizer_path = '../Llama2/7B/tokenizer.model'
|
12 |
load_dotenv()
|
|
|
13 |
|
14 |
elif env == 'spaces':
|
15 |
model_path = 'akfung/llama_supreme'
|
16 |
tokenizer_path = 'akfung/llama_supreme'
|
|
|
17 |
|
18 |
elif env == 'gcp':
|
19 |
model_path = 'model/'
|
@@ -47,7 +49,6 @@ headers = {
|
|
47 |
"Content-Type": "application/json"
|
48 |
}
|
49 |
|
50 |
-
embedding_path = os.environ.get('EMBEDDING_PATH')
|
51 |
streaming_url = os.environ.get('STREAMING_URL')
|
52 |
job_url = os.environ.get('JOB_URL')
|
53 |
|
|
|
10 |
model_path = '../merged_llama2/'
|
11 |
tokenizer_path = '../Llama2/7B/tokenizer.model'
|
12 |
load_dotenv()
|
13 |
+
embedding_path = "embedding_model/"
|
14 |
|
15 |
elif env == 'spaces':
|
16 |
model_path = 'akfung/llama_supreme'
|
17 |
tokenizer_path = 'akfung/llama_supreme'
|
18 |
+
embedding_path = "/embedding_model/"
|
19 |
|
20 |
elif env == 'gcp':
|
21 |
model_path = 'model/'
|
|
|
49 |
"Content-Type": "application/json"
|
50 |
}
|
51 |
|
|
|
52 |
streaming_url = os.environ.get('STREAMING_URL')
|
53 |
job_url = os.environ.get('JOB_URL')
|
54 |
|
src/model.py
CHANGED
@@ -4,7 +4,7 @@ import time
|
|
4 |
# from google.cloud import storage
|
5 |
from sentence_transformers import SentenceTransformer
|
6 |
|
7 |
-
from .config import max_new_tokens, streaming_url, job_url, default_payload, headers
|
8 |
from .db.db_utilities import query_db
|
9 |
|
10 |
class Model:
|
@@ -15,7 +15,7 @@ class Model:
|
|
15 |
max_new_tokens:int=max_new_tokens):
|
16 |
self.max_new_tokens = max_new_tokens
|
17 |
# self.embedding_model = SentenceTransformer('multi-qa-mpnet-base-dot-v1')
|
18 |
-
self.embedding_model = SentenceTransformer(
|
19 |
|
20 |
|
21 |
def inference(self, query:str, table:str):
|
@@ -35,11 +35,11 @@ class Model:
|
|
35 |
if len(matches) > 0:
|
36 |
match = '"""' + matches[0][0] + '"""'
|
37 |
|
38 |
-
context = "You are the United States Supreme Court. Use the following historical opinion
|
39 |
else:
|
40 |
context = 'You are the United States Supreme Court. Give your ruling on a court case description.'
|
41 |
|
42 |
-
return context + " Answer in less than 400 words
|
43 |
|
44 |
def query_model(self, query:str, table:str, default_payload:dict=default_payload, timeout:int=60, **kwargs) -> str:
|
45 |
"""Query the model api on runpod. Runs for 60s by default. Generator response until job is complete"""
|
@@ -57,7 +57,7 @@ class Model:
|
|
57 |
"content": query,
|
58 |
}
|
59 |
]
|
60 |
-
|
61 |
default_payload["input"]["prompt"] = augmented_prompt_template
|
62 |
job_id = requests.post(job_url, json=default_payload, headers=headers).json()['id']
|
63 |
for i in range(timeout):
|
|
|
4 |
# from google.cloud import storage
|
5 |
from sentence_transformers import SentenceTransformer
|
6 |
|
7 |
+
from .config import max_new_tokens, streaming_url, job_url, default_payload, headers, embedding_path
|
8 |
from .db.db_utilities import query_db
|
9 |
|
10 |
class Model:
|
|
|
15 |
max_new_tokens:int=max_new_tokens):
|
16 |
self.max_new_tokens = max_new_tokens
|
17 |
# self.embedding_model = SentenceTransformer('multi-qa-mpnet-base-dot-v1')
|
18 |
+
self.embedding_model = SentenceTransformer(embedding_path)
|
19 |
|
20 |
|
21 |
def inference(self, query:str, table:str):
|
|
|
35 |
if len(matches) > 0:
|
36 |
match = '"""' + matches[0][0] + '"""'
|
37 |
|
38 |
+
context = "You are the United States Supreme Court. Use the following historical opinion to give your ruling on a court case description. Historical opinion: " + match
|
39 |
else:
|
40 |
context = 'You are the United States Supreme Court. Give your ruling on a court case description.'
|
41 |
|
42 |
+
return context + " Answer in less than 400 words in the format Opinion: <opinion> "
|
43 |
|
44 |
def query_model(self, query:str, table:str, default_payload:dict=default_payload, timeout:int=60, **kwargs) -> str:
|
45 |
"""Query the model api on runpod. Runs for 60s by default. Generator response until job is complete"""
|
|
|
57 |
"content": query,
|
58 |
}
|
59 |
]
|
60 |
+
|
61 |
default_payload["input"]["prompt"] = augmented_prompt_template
|
62 |
job_id = requests.post(job_url, json=default_payload, headers=headers).json()['id']
|
63 |
for i in range(timeout):
|