fwittel committed on
Commit
22bc6be
1 Parent(s): d4725fd

switch to AutoModelForCausalLM

Browse files
Files changed (2) hide show
  1. handler.py +6 -6
  2. test.py +0 -13
handler.py CHANGED
@@ -1,21 +1,21 @@
1
  import torch
2
  from typing import Dict, List, Any
3
- from transformers import AutoModel, AutoTokenizer, pipeline
4
 
5
  # check for GPU
6
  device = 0 if torch.cuda.is_available() else -1
7
 
 
8
  class EndpointHandler:
9
  def __init__(self, path=""):
10
  # load the model
11
  tokenizer = AutoTokenizer.from_pretrained(path)
12
- model = AutoModel.from_pretrained(path, low_cpu_mem_usage=True)
 
 
13
  # create inference pipeline
14
- # Do I have to check device?
15
- self.pipeline = pipeline(
16
- "text-generation", model=model, tokenizer=tokenizer, device=device)
17
 
18
- # (Might have to adjust typing)
19
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
20
  inputs = data.pop("inputs", data)
21
  parameters = data.pop("parameters", None)
 
1
  import torch
2
  from typing import Dict, List, Any
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
4
 
5
  # check for GPU
6
  device = 0 if torch.cuda.is_available() else -1
7
 
8
+
9
  class EndpointHandler:
10
  def __init__(self, path=""):
11
  # load the model
12
  tokenizer = AutoTokenizer.from_pretrained(path)
13
+ # model = AutoModel.from_pretrained(path, low_cpu_mem_usage=True)
14
+ model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True)
15
+ # model = AutoModelForSeq2SeqLM.from_pretrained(path, low_cpu_mem_usage=True)
16
  # create inference pipeline
17
+ self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)
 
 
18
 
 
19
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
20
  inputs = data.pop("inputs", data)
21
  parameters = data.pop("parameters", None)
test.py DELETED
@@ -1,13 +0,0 @@
1
- from handler import EndpointHandler
2
-
3
- # init handler
4
- my_handler = EndpointHandler(path=".")
5
-
6
- # prepare sample payload
7
- payload = {"inputs": "I am Bob and I want to "}
8
-
9
- # test the handler
10
- pred=my_handler(payload)
11
-
12
- # show results
13
- print("pred", pred)