renede committed on
Commit
4675617
1 Parent(s): 25b17fe

Delete handler.py

Browse files
Files changed (1) hide show
  1. handler.py +0 -37
handler.py DELETED
@@ -1,37 +0,0 @@
1
- from typing import Any, Dict, List
2
-
3
- import torch
4
- import transformers
5
- from transformers import AutoModelForCausalLM, AutoTokenizer
6
-
7
# Pick the half-precision dtype used when loading the model.
# bfloat16 is chosen on CUDA devices with compute capability 8.0 or newer;
# everything else falls back to float16. The availability check keeps the
# module importable on machines without a CUDA device, where querying the
# device capability would raise.
dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
    else torch.float16
)
8
-
9
-
10
class EndpointHandler:
    """Callable wrapper that serves a causal LM via a text-generation pipeline.

    The tokenizer and model are loaded once in ``__init__``; every call then
    runs the request's prompt through the pipeline with one fixed generation
    configuration stored on the instance.
    """

    def __init__(self, path=""):
        """Load the tokenizer and model from *path* and build the pipeline.

        Args:
            path: Location or identifier forwarded to ``from_pretrained`` for
                both the tokenizer and the model.
        """
        # NOTE: trust_remote_code=True lets the checkpoint run its own Python
        # code at load time; only point `path` at repositories you trust.
        tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            path,
            return_dict=True,
            device_map="auto",
            load_in_8bit=True,
            torch_dtype=dtype,
            trust_remote_code=True,
        )

        # Start from the model's own generation defaults and override only
        # what this endpoint needs.
        generation_config = model.generation_config
        generation_config.max_new_tokens = 60
        # Request deterministic (greedy) decoding explicitly. A temperature of
        # 0 is not a valid sampling temperature and fails when the checkpoint's
        # defaults enable sampling, so sampling is switched off instead.
        generation_config.do_sample = False
        generation_config.num_return_sequences = 1
        # The tokenizer's EOS id doubles as the padding id.
        generation_config.pad_token_id = tokenizer.eos_token_id
        generation_config.eos_token_id = tokenizer.eos_token_id
        self.generation_config = generation_config

        self.pipeline = transformers.pipeline(
            "text-generation", model=model, tokenizer=tokenizer
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate text for one request.

        Args:
            data: Request payload. The prompt is read from the ``"inputs"``
                key; when that key is absent the payload itself is passed to
                the pipeline as the prompt.

        Returns:
            Whatever the pipeline returns, unchanged (presumably a list of
            dicts for a single string prompt -- confirm against the pipeline
            version in use).
        """
        # .get() rather than .pop(): the caller's payload is left untouched.
        prompt = data.get("inputs", data)
        return self.pipeline(prompt, generation_config=self.generation_config)