jordiclive committed on
Commit
da58751
1 Parent(s): fcd785c

Upload handler.py

Files changed (1)
handler.py +73 -0
handler.py ADDED
@@ -0,0 +1,73 @@
from typing import Any, Dict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

dtype = torch.bfloat16


class EndpointHandler:
    def __init__(self, path=""):
        # Load tokenizer and model; device_map="auto" places the weights on
        # the available accelerator(s), and bfloat16 halves memory vs. fp32.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path, device_map="auto", torch_dtype=dtype
        )
        # Some causal LMs ship without a pad token; fall back to EOS so that
        # padding (and the ignore_index below) is well defined.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Create the inference pipeline used for the generation path.
        self.pipeline = pipeline(
            "text-generation", model=self.model, tokenizer=self.tokenizer
        )
        # Per-token cross-entropy (reduction="none") so we can return a
        # log-likelihood for every position rather than a single scalar.
        self.ce = torch.nn.CrossEntropyLoss(
            ignore_index=self.tokenizer.pad_token_id, reduction="none"
        )

    def compute_log_likelihood(self, lm_logits, input_ids):
        # Shift by one: the logits at position i predict the token at i + 1.
        predictions = lm_logits[..., :-1, :].contiguous()
        target_ids = input_ids[..., 1:].contiguous()

        ce_loss = self.ce(
            predictions.view(-1, predictions.size(-1)),
            target_ids.view(-1),
        )
        # Negate the loss to get per-token log-likelihoods; [0] selects the
        # single sequence in the batch.
        return -ce_loss.view_as(target_ids)[0]

    def __call__(self, data: Dict[str, Any]):
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)
        # Scoring path: return logits and per-token log-likelihoods instead of
        # generating. Guard against parameters being None (the original
        # .get() call would raise here), and pop the flag so it is never
        # forwarded to the pipeline as an unknown keyword argument.
        if parameters is not None and parameters.pop("no_generation", False):
            input_tokens = self.tokenizer(
                [inputs], return_tensors="pt", padding=False
            )
            # Move all tensors to the model's device.
            for t in input_tokens:
                if torch.is_tensor(input_tokens[t]):
                    input_tokens[t] = input_tokens[t].to(self.model.device)

            with torch.no_grad():
                logits = self.model(
                    input_ids=input_tokens["input_ids"],
                    attention_mask=input_tokens["attention_mask"],
                )[0]
            log_likelihood = self.compute_log_likelihood(
                logits, input_tokens["input_ids"]
            )
            return (logits, log_likelihood)
        # Generation path: delegate to the text-generation pipeline, passing
        # any remaining parameters through.
        if parameters is not None:
            prediction = self.pipeline(inputs, **parameters)
        else:
            prediction = self.pipeline(inputs)
        return prediction


# if __name__ == "__main__":
#     model = EndpointHandler("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
#
#     data = {
#         "inputs": "Can you please let us know more details about your ",
#         "parameters": {
#             "no_generation": True,
#             # "function_to_apply": "none",
#             # "return_text": False,
#         },
#     }
#     x = model(data)
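
For reference, a minimal sketch of how a client might call this handler once the repo is deployed as a Hugging Face Inference Endpoint. The endpoint URL and token below are placeholders, not values from this commit, and the sketch exercises the generation path, since the no_generation path returns raw torch tensors that are not JSON-serializable over HTTP.

import requests

API_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
headers = {
    "Authorization": "Bearer <hf_token>",  # placeholder
    "Content-Type": "application/json",
}

# "parameters" are forwarded to the text-generation pipeline,
# e.g. max_new_tokens.
payload = {
    "inputs": "Can you please let us know more details about your ",
    "parameters": {"max_new_tokens": 50},
}
response = requests.post(API_URL, headers=headers, json=payload)
print(response.json())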