Adapters
Inference Endpoints
jeremyarancio committed
Commit 7bf309f
Parent: f85d258

Update handler

Files changed (1): handler.py (+5, -3)
handler.py CHANGED
@@ -3,10 +3,12 @@ import logging
 
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from peft import PeftConfig, PeftModel
+import torch.cuda
 
 
 LOGGER = logging.getLogger(__name__)
 logging.basicConfig(level=logging.INFO)
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 class EndpointHandler():
@@ -29,13 +31,13 @@ class EndpointHandler():
         if prompt is None:
             raise ValueError("Missing prompt.")
         # Preprocess
-        inputs = self.tokenizer(prompt, return_tensors="pt")
+        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(device)
         # Forward
         LOGGER.info(f"Start generation.")
         if parameters is not None:
-            output = self.model.generate(**inputs, **parameters)
+            output = self.model.generate(input_ids=input_ids, **parameters)
         else:
-            output = self.model.generate(**inputs)
+            output = self.model.generate(input_ids=input_ids)
         # Postprocess
         prediction = self.tokenizer.decode(output[0])
         LOGGER.info(f"Generated text: {prediction}")