mishtert commited on
Commit
9ccafee
1 Parent(s): cd5fdb5

Upload hfutils.py

Browse files
Files changed (1) hide show
  1. hfutils.py +41 -0
hfutils.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+
3
+ # API_URL = "https://api-inference.huggingface.co/models/facebook/bart-large-mnli"
4
+ # headers = {"Authorization": f"Bearer {hft}"}
5
+
6
+ API_URL = "https://api-inference.huggingface.co/models/deepset/roberta-base-squad2"
7
+ headers = {"Authorization": "Bearer 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"}
8
+
9
+
10
+
11
+ def query(payload):
12
+ response = requests.post(API_URL, headers=headers, json=payload)
13
+ return response.json()
14
+
15
+ # output = query({
16
+ # "inputs": {
17
+ # "question": "What's my name?",
18
+ # "context": "My name is Clara and I live in Berkeley.",
19
+ # },
20
+ # })
21
+
22
+
23
+ def get_ans(question,context):
24
+ output = query({
25
+ "inputs": {
26
+ "question": question,
27
+ "context": context,
28
+ },
29
+ })
30
+ return output
31
+
32
+
33
+
34
+ def get_label_score_dict(row, threshold):
35
+ result_dict = dict()
36
+ for _label, _score in zip(row['labels'], row['scores']):
37
+ if _score > threshold:
38
+ result_dict.update({_label: 1})
39
+ else:
40
+ result_dict.update({_label: 0})
41
+ return result_dict