samson-s commited on
Commit
da85024
1 Parent(s): 44cd819

Add gpu specification

Browse files
Files changed (1) hide show
  1. handler.py +2 -1
handler.py CHANGED
@@ -1,6 +1,6 @@
1
  from transformers import pipeline, AutoFeatureExtractor
2
  from PIL import Image
3
- import io
4
 
5
 
6
  class EndpointHandler:
@@ -9,6 +9,7 @@ class EndpointHandler:
9
  "image-to-text",
10
  model=path,
11
  feature_extractor=AutoFeatureExtractor,
 
12
  )
13
 
14
  def __call__(self, data) -> str:
 
1
  from transformers import pipeline, AutoFeatureExtractor
2
  from PIL import Image
3
+ import torch
4
 
5
 
6
  class EndpointHandler:
 
9
  "image-to-text",
10
  model=path,
11
  feature_extractor=AutoFeatureExtractor,
12
+ device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
13
  )
14
 
15
  def __call__(self, data) -> str: