fwittel commited on
Commit
3a31a5b
1 Parent(s): 040e104

Added device-selection to handler.py

Browse files
Files changed (1) hide show
  1. handler.py +4 -1
handler.py CHANGED
@@ -1,6 +1,9 @@
 
1
  from typing import Dict, List, Any
2
  from transformers import AutoModel, AutoTokenizer, pipeline
3
 
 
 
4
 
5
  class EndpointHandler:
6
  def __init__(self, path=""):
@@ -10,7 +13,7 @@ class EndpointHandler:
10
  # create inference pipeline
11
  # Do I have to check device?
12
  self.pipeline = pipeline(
13
- "text-generation", model=model, tokenizer=tokenizer)
14
 
15
  # (Might have to adjust typing)
16
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
 
1
+ import torch
2
  from typing import Dict, List, Any
3
  from transformers import AutoModel, AutoTokenizer, pipeline
4
 
5
+ # check for GPU
6
+ device = 0 if torch.cuda.is_available() else -1
7
 
8
  class EndpointHandler:
9
  def __init__(self, path=""):
 
13
  # create inference pipeline
14
  # Do I have to check device?
15
  self.pipeline = pipeline(
16
+ "text-generation", model=model, tokenizer=tokenizer, device=device)
17
 
18
  # (Might have to adjust typing)
19
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]: