moooji commited on
Commit
73e09fb
1 Parent(s): aabe3a4

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +3 -3
handler.py CHANGED
@@ -9,7 +9,7 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
9
 
10
  class EndpointHandler():
11
  def __init__(self, path=""):
12
- self.model = Swinv2Model.from_pretrained("microsoft/swinv2-large-patch4-window12to24-192to384-22kto1k-ft").to(device)
13
  self.processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-large-patch4-window12to24-192to384-22kto1k-ft")
14
 
15
  def __call__(self, data: Any) -> List[float]:
@@ -21,6 +21,6 @@ class EndpointHandler():
21
  with torch.no_grad():
22
  outputs = self.model(**inputs)
23
 
24
- last_hidden_states = outputs.last_hidden_state
25
- return last_hidden_states.tolist()
26
 
 
9
 
10
  class EndpointHandler():
11
  def __init__(self, path=""):
12
+ self.model = Swinv2Model.from_pretrained("microsoft/swinv2-large-patch4-window12to24-192to384-22kto1k-ft", add_pooling_layer = True).to(device)
13
  self.processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-large-patch4-window12to24-192to384-22kto1k-ft")
14
 
15
  def __call__(self, data: Any) -> List[float]:
 
21
  with torch.no_grad():
22
  outputs = self.model(**inputs)
23
 
24
+ pooler_output = outputs.pooler_output
25
+ return pooler_output.tolist()
26