Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files- helpers.py +45 -0
- huggingface.py +38 -0
helpers.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from db import DB
|
2 |
+
|
3 |
+
def validate_key(key, db: DB):
    """Validate a survey key.

    A key is valid when it was previously generated ('surveykeys'
    collection) and has not already been used to submit a response
    ('responses' collection).

    Args:
        key: The survey key supplied by the respondent.
        db: Database handle providing ``get_document_ids(collection)``.

    Returns:
        bool: True if the key is usable, False otherwise.
    """
    generated_keys = db.get_document_ids('surveykeys')
    survey_done = db.get_document_ids('responses')
    # Return the boolean expression directly instead of the
    # if/return True/return False pattern.
    return (key in generated_keys) and (key not in survey_done)
|
10 |
+
|
11 |
+
def label_speaker(speaker):
    """Map a UI speaker label to its chat-prompt role tag.

    Returns None for any speaker other than 'You' / 'Other Party',
    mirroring the original implicit fall-through.
    """
    role_tags = {
        'You': '### Human',
        'Other Party': '### Assistant',
    }
    return role_tags.get(speaker)
|
16 |
+
|
17 |
+
def expand_for(num):
    """Return True when *num* is below 2, False otherwise.

    Args:
        num: A number (compared against 2).

    Returns:
        bool: ``num < 2``.
    """
    # Comparison already yields a bool; no if/else needed.
    return num < 2
|
22 |
+
|
23 |
+
def check_none_in_dict(dictionary):
    """Report whether any value in *dictionary* is None.

    Prints a diagnostic message either way; returns True on the first
    None value found, False when there are none.
    """
    for key in dictionary:
        if dictionary[key] is not None:
            continue
        print(f"Found 'None' value for key: {key}")
        return True
    print("No 'None' values found in the dictionary.")
    return False
|
30 |
+
|
31 |
+
def check_conversation_input_validity(lst):
    """Check that the first None placeholder in *lst* is validly placed.

    Args:
        lst: A list of conversation lines (strings), possibly containing
            a None placeholder marking where the next reply goes.

    Returns:
        False if the first None appears before index 2; True if the entry
        immediately before it is a '### Human' line; False otherwise.
        NOTE(review): when *lst* contains no None the function falls off
        the end and returns None (implicit) — preserved for backward
        compatibility; confirm callers rely only on truthiness.
    """
    for i, item in enumerate(lst):
        if item is not None:
            continue
        if i < 2:
            # A placeholder this early cannot follow a full Human turn.
            return False
        # The slot just before the placeholder must be a human turn.
        return lst[i - 1].startswith('### Human')
|
42 |
+
|
43 |
+
def wrap_conversation(text: str):
    """Wrap raw user text in the Human/Assistant prompt template."""
    return f"### Human: {text} ### Assistant:"
|
huggingface.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from huggingface_hub import InferenceClient
|
2 |
+
import requests
|
3 |
+
|
4 |
+
from helpers import wrap_conversation
|
5 |
+
|
6 |
+
class Model:
    """Thin client for a Hugging Face text-generation inference endpoint."""

    def __init__(self, model_name: str, API_URL: str):
        """Store the model name and the endpoint URL queries are POSTed to.

        Args:
            model_name: Human-readable name of the model (informational).
            API_URL: Full URL of the inference endpoint.
        """
        self.model_name = model_name
        self.API_URL = API_URL

    def _query(self, payload):
        """POST *payload* as JSON to the endpoint and return the decoded
        JSON response body."""
        headers = {
            # "Authorization": "Bearer XXX",
            "Content-Type": "application/json"
        }
        response = requests.post(self.API_URL, headers=headers, json=payload)
        return response.json()

    def generate_text(self, text_input: str) -> str:
        """Generate a single assistant reply for *text_input*.

        Wraps the input in the Human/Assistant prompt template, queries
        the endpoint, and strips the echoed prompt from the result.
        """
        payload = {
            "inputs": wrap_conversation(text_input),
            "parameters": {
                "temperature": 0.9,    # sampling temperature
                "max_new_tokens": 50   # cap on reply length
            }
        }

        output = self._query(payload)
        # NOTE(review): the HF Inference API commonly returns a *list* of
        # dicts ([{'generated_text': ...}]); this indexes a dict directly.
        # Confirm the endpoint's actual response shape.
        return self._get_first_reply(text_input, output['generated_text'])

    @staticmethod
    def _get_first_reply(original_text, output_text):
        """Strip the echoed prompt from *output_text* and return only the
        first reply (everything before the next '#' marker)."""
        new_text = output_text[len(wrap_conversation(original_text)):]
        final_text = new_text.split('#', 1)[0]
        return final_text
|