ksee commited on
Commit
9c7155b
1 Parent(s): 9e4637a

Upload 2 files

Browse files
Files changed (2) hide show
  1. helpers.py +45 -0
  2. huggingface.py +38 -0
helpers.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from db import DB
2
+
3
+ def validate_key(key, db: DB):
4
+ """ Validate the survey key. """
5
+ generated_keys = db.get_document_ids('surveykeys')
6
+ survey_done = db.get_document_ids('responses')
7
+ if (key in generated_keys) and (key not in survey_done):
8
+ return True
9
+ return False
10
+
11
+ def label_speaker(speaker):
12
+ if speaker == 'You':
13
+ return '### Human'
14
+ elif speaker == 'Other Party':
15
+ return '### Assistant'
16
+
17
+ def expand_for(num):
18
+ if num < 2:
19
+ return True
20
+ else:
21
+ return False
22
+
23
+ def check_none_in_dict(dictionary):
24
+ for key, value in dictionary.items():
25
+ if value is None:
26
+ print(f"Found 'None' value for key: {key}")
27
+ return True
28
+ print("No 'None' values found in the dictionary.")
29
+ return False
30
+
31
+ def check_conversation_input_validity(lst):
32
+ for i in range(len(lst)):
33
+ item = lst[i]
34
+ if item is None:
35
+ if i < 2:
36
+ return False
37
+ last_item = lst[i-1]
38
+ if last_item[0:9] == '### Human':
39
+ return True
40
+ else:
41
+ return False
42
+
43
+ def wrap_conversation(text: str):
44
+ wrapped_text = '### Human: ' + text + " ### Assistant:"
45
+ return wrapped_text
huggingface.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import InferenceClient
2
+ import requests
3
+
4
+ from helpers import wrap_conversation
5
+
6
+ class Model:
7
+
8
+ def __init__(self, model_name: str, API_URL: str):
9
+ self.model_name = model_name
10
+ self.API_URL = API_URL
11
+
12
+ def _query(self, payload):
13
+ headers = {
14
+ # "Authorization": "Bearer XXX",
15
+ "Content-Type": "application/json"
16
+ }
17
+ response = requests.post(self.API_URL, headers=headers, json=payload)
18
+ return response.json()
19
+
20
+ def generate_text(self, text_input: str) -> str:
21
+ # output = self.client.text_generation(text_input)
22
+ # return output
23
+ payload = {
24
+ "inputs": wrap_conversation(text_input),
25
+ "parameters": {
26
+ "temperature": 0.9,
27
+ "max_new_tokens": 50
28
+ }
29
+ }
30
+
31
+ output = self._query(payload)
32
+ return self._get_first_reply(text_input, output['generated_text'])
33
+
34
+ @staticmethod
35
+ def _get_first_reply(original_text, output_text):
36
+ new_text = output_text[len(wrap_conversation(original_text)):]
37
+ final_text = new_text.split('#', 1)[0]
38
+ return final_text