umm-maybe commited on
Commit
65dce62
1 Parent(s): 86edf8d

Upload hf_utils.py

Browse files
Files changed (1) hide show
  1. hf_utils.py +51 -0
hf_utils.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import time
3
+ import re
4
+
5
+ # function for Huggingface API calls
6
+ def query(payload, model_path, headers):
7
+ API_URL = "https://api-inference.huggingface.co/models/" + model_path
8
+ for retry in range(3):
9
+ response = requests.post(API_URL, headers=headers, json=payload)
10
+ if response.status_code == requests.codes.ok:
11
+ try:
12
+ results = response.json()
13
+ return results
14
+ except:
15
+ print('Invalid response received from server')
16
+ print(response)
17
+ return None
18
+ else:
19
+ # Not connected to internet maybe?
20
+ if response.status_code==404:
21
+ print('Are you connected to the internet?')
22
+ print('URL attempted = '+API_URL)
23
+ break
24
+ if response.status_code==503:
25
+ print(response.json()['error'])
26
+ time.sleep(response.json()['estimated_time'])
27
+ continue
28
+ if response.status_code==504:
29
+ print('504 Gateway Timeout')
30
+ else:
31
+ print('Unsuccessful request, status code '+ str(response.status_code))
32
+ # print(response.json()) #debug only
33
+ print(payload)
34
+
35
+ def generate_text(prompt, model_path, text_generation_parameters, headers):
36
+ start_time = time.time()
37
+ options = {'use_cache': False, 'wait_for_model': True}
38
+ payload = {"inputs": prompt, "parameters": text_generation_parameters, "options": options}
39
+ output_list = query(payload, model_path, headers)
40
+ if not output_list:
41
+ print('Generation failed')
42
+ end_time = time.time()
43
+ duration = round(end_time - start_time, 1)
44
+ stringlist = []
45
+ if output_list and 'generated_text' in output_list[0].keys():
46
+ print(f'{len(output_list)} sample(s) of text generated in {duration} seconds.')
47
+ for gendict in output_list:
48
+ stringlist.append(gendict['generated_text'])
49
+ else:
50
+ print(output_list)
51
+ return(stringlist)