Spaces:
Sleeping
Sleeping
fixed chat template issues
Browse files- setup.py +1 -1
- src/config.py +3 -2
- src/model.py +18 -7
setup.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
from sentence_transformers import SentenceTransformer
|
2 |
|
3 |
embedding_model = SentenceTransformer('multi-qa-mpnet-base-dot-v1')
|
4 |
-
embedding_model.save('
|
|
|
1 |
from sentence_transformers import SentenceTransformer
|
2 |
|
3 |
embedding_model = SentenceTransformer('multi-qa-mpnet-base-dot-v1')
|
4 |
+
embedding_model.save('embedding_model/')
|
src/config.py
CHANGED
@@ -48,11 +48,12 @@ headers = {
|
|
48 |
}
|
49 |
|
50 |
embedding_path = os.environ.get('EMBEDDING_PATH')
|
51 |
-
streaming_url =
|
52 |
-
job_url =
|
53 |
|
54 |
default_payload = { "input": {
|
55 |
"prompt": "Who is the president of the United States?",
|
|
|
56 |
"sampling_params": {
|
57 |
"max_tokens": os.environ.get('max_new_tokens', 400),
|
58 |
"n": 1,
|
|
|
48 |
}
|
49 |
|
50 |
embedding_path = os.environ.get('EMBEDDING_PATH')
|
51 |
+
streaming_url = os.environ.get('STREAMING_URL')
|
52 |
+
job_url = os.environ.get('JOB_URL')
|
53 |
|
54 |
default_payload = { "input": {
|
55 |
"prompt": "Who is the president of the United States?",
|
56 |
+
"apply_chat_template": True,
|
57 |
"sampling_params": {
|
58 |
"max_tokens": os.environ.get('max_new_tokens', 400),
|
59 |
"n": 1,
|
src/model.py
CHANGED
@@ -15,7 +15,7 @@ class Model:
|
|
15 |
max_new_tokens:int=max_new_tokens):
|
16 |
self.max_new_tokens = max_new_tokens
|
17 |
# self.embedding_model = SentenceTransformer('multi-qa-mpnet-base-dot-v1')
|
18 |
-
self.embedding_model = SentenceTransformer("
|
19 |
|
20 |
|
21 |
def inference(self, query:str, table:str):
|
@@ -35,19 +35,30 @@ class Model:
|
|
35 |
if len(matches) > 0:
|
36 |
match = '"""' + matches[0][0] + '"""'
|
37 |
|
38 |
-
context = "Use the following historical opinion delimited by triple quotes to give your ruling on a court case description. Historical opinion: " + match  [line truncated in scrape; reconstructed from the added-side counterpart — verify against original commit]
|
39 |
else:
|
40 |
-
context = 'Give your ruling on a court case description.'
|
41 |
|
42 |
-
return context +
|
43 |
|
44 |
def query_model(self, query:str, table:str, default_payload:dict=default_payload, timeout:int=60, **kwargs) -> str:
|
45 |
"""Query the model api on runpod. Runs for 60s by default. Generator response until job is complete"""
|
46 |
|
47 |
-
|
48 |
for k,v in kwargs:
|
49 |
default_payload['input']['sampling_params'][k] = v
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
job_id = requests.post(job_url, json=default_payload, headers=headers).json()['id']
|
52 |
for i in range(timeout):
|
53 |
time.sleep(1)
|
@@ -77,4 +88,4 @@ class Model:
|
|
77 |
# for object_name in model_file_paths:
|
78 |
# blob = bucket.blob(object_name)
|
79 |
# blob.download_to_filename(object_name)
|
80 |
-
|
|
|
15 |
max_new_tokens:int=max_new_tokens):
|
16 |
self.max_new_tokens = max_new_tokens
|
17 |
# self.embedding_model = SentenceTransformer('multi-qa-mpnet-base-dot-v1')
|
18 |
+
self.embedding_model = SentenceTransformer("embedding_model/")
|
19 |
|
20 |
|
21 |
def inference(self, query:str, table:str):
|
|
|
35 |
if len(matches) > 0:
|
36 |
match = '"""' + matches[0][0] + '"""'
|
37 |
|
38 |
+
context = "You are the United States Supreme Court. Use the following historical opinion delimited by triple quotes to give your ruling on a court case description. Historical opinion: " + match
|
39 |
else:
|
40 |
+
context = 'You are the United States Supreme Court. Give your ruling on a court case description.'
|
41 |
|
42 |
+
return context + " Answer in less than 400 words. Do not introduce yourself"
|
43 |
|
44 |
def query_model(self, query:str, table:str, default_payload:dict=default_payload, timeout:int=60, **kwargs) -> str:
|
45 |
"""Query the model api on runpod. Runs for 60s by default. Generator response until job is complete"""
|
46 |
|
47 |
+
context = self.get_context(query=query, table=table)
|
48 |
for k,v in kwargs:
|
49 |
default_payload['input']['sampling_params'][k] = v
|
50 |
+
augmented_prompt_template = [
|
51 |
+
{
|
52 |
+
"role": "system",
|
53 |
+
"content": context,
|
54 |
+
},
|
55 |
+
{
|
56 |
+
"role": "user",
|
57 |
+
"content": query,
|
58 |
+
}
|
59 |
+
]
|
60 |
+
print(augmented_prompt_template)
|
61 |
+
default_payload["input"]["prompt"] = augmented_prompt_template
|
62 |
job_id = requests.post(job_url, json=default_payload, headers=headers).json()['id']
|
63 |
for i in range(timeout):
|
64 |
time.sleep(1)
|
|
|
88 |
# for object_name in model_file_paths:
|
89 |
# blob = bucket.blob(object_name)
|
90 |
# blob.download_to_filename(object_name)
|
91 |
+
|