alexadam committed on
Commit
826348d
·
verified ·
1 Parent(s): 30126e9

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +50 -0
handler.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+
5
+
6
+
7
class EndpointHandler():
    """Inference handler that returns mask-averaged GPT-2 logits per document.

    The instance is callable: it receives a request payload, runs every
    input document through the causal language model, averages the logits
    over the non-padding token positions, and returns the averages as
    plain nested lists.
    """

    # Every document is truncated/padded to this many tokens.
    MAX_LENGTH = 512

    def __init__(self, path: str = "") -> None:
        """Load the ``gpt2`` model and tokenizer.

        Args:
            path: Accepted for interface compatibility; it is not used, the
                ``"gpt2"`` checkpoint name is hard-coded.
        """
        # Prefer the GPU, but stay usable on CPU-only hosts.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        # Half precision is only used on the GPU; fall back to float32 on CPU.
        dtype = torch.float16 if self.device.type == "cuda" else torch.float32

        # load the model
        self.model = AutoModelForCausalLM.from_pretrained(
            "gpt2", torch_dtype=dtype, output_hidden_states=True
        )
        self.model = self.model.to(self.device)

        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
        # padding="max_length" in __call__ needs a pad token; reuse EOS when
        # the tokenizer defines none (the stock GPT-2 tokenizer does not).
        # Padded positions are excluded through the attention mask anyway.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def __call__(self, data: Dict[str, object]) -> Dict[str, list]:
        """Compute the mean logits of each input document.

        Args:
            data: Request payload. The documents are read from the
                ``"inputs"`` key (the key is removed from ``data``); when
                the key is absent the payload itself is iterated. A single
                string is treated as one document.

        Returns:
            ``{"logits": [...]}`` with one entry per document, in input
            order. Each entry is the ``tolist()`` of the model logits
            averaged over the positions whose attention mask is 1, keeping
            the leading batch dimension of size one.
        """
        # process input
        inputs = data.pop("inputs", data)
        if isinstance(inputs, str):
            # Iterating a bare string would yield single characters.
            inputs = [inputs]

        all_logits = []
        for doc in inputs:
            tokens = self.tokenizer(
                doc,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.MAX_LENGTH,
            )
            token_ids = tokens.input_ids.to(self.device)
            token_mask = tokens.attention_mask.to(self.device)

            with torch.no_grad():
                out = self.model(token_ids, attention_mask=token_mask)

            # Accumulate in float32: summing up to MAX_LENGTH half-precision
            # logits can exceed the float16 range.
            mask = token_mask.unsqueeze(-1).to(torch.float32)
            summed = (out.logits.to(torch.float32) * mask).sum(1)
            # Clamp so an all-padding (empty) document yields zeros, not NaN.
            counts = mask.sum(1).clamp(min=1)
            meaned_logits = summed / counts

            all_logits.append(meaned_logits.cpu().tolist())

        # postprocess the prediction
        return {"logits": all_logits}