letrunglinh commited on
Commit
dd58cce
1 Parent(s): 31f75a0

Update pairwise_model.py

Browse files
Files changed (1) hide show
  1. pairwise_model.py +8 -3
pairwise_model.py CHANGED
@@ -8,7 +8,7 @@ from optimum.intel import OVModelForQuestionAnswering
8
  import openvino.inference_engine as ie
9
  import os
10
  import gradio as gr
11
-
12
  AUTH_TOKEN = "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"
13
 
14
  tokenizer = AutoTokenizer.from_pretrained('nguyenvulebinh/vi-mrc-base',
@@ -35,7 +35,12 @@ class PairwiseModel_modify(nn.Module):
35
 
36
  def forward(self, ids, masks):
37
  # Export the model to ONNX format
38
- input_feed = {"input_ids": ids.cpu().numpy().astype(np.int64), "attention_mask": masks.cpu().numpy().astype(np.int64)}
 
 
 
 
 
39
  # Specify the input shapes (batch_size, max_sequence_length)
40
  input_shapes = {"input_ids": ids.shape, "attention_mask": masks.shape}
41
 
@@ -57,7 +62,7 @@ class PairwiseModel_modify(nn.Module):
57
  tmp["question"] = question
58
  valid_dataset = SiameseDatasetStage1(tmp, tokenizer, self.max_length, is_test=True)
59
  valid_loader = DataLoader(valid_dataset, batch_size=self.batch_size, collate_fn=collate_fn,
60
- num_workers=0, shuffle=False, pin_memory=True)
61
  preds = []
62
  with torch.no_grad():
63
  bar = enumerate(valid_loader)
 
8
  import openvino.inference_engine as ie
9
  import os
10
  import gradio as gr
11
+ from multiprocessing import cpu_count
12
  AUTH_TOKEN = "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"
13
 
14
  tokenizer = AutoTokenizer.from_pretrained('nguyenvulebinh/vi-mrc-base',
 
35
 
36
  def forward(self, ids, masks):
37
  # Export the model to ONNX format
38
+ ids_np = ids.cpu().numpy().astype(np.int64)
39
+ masks_np = masks.cpu().numpy().astype(np.int64)
40
+ ids_device = torch.from_numpy(ids_np).to(self.device)
41
+ masks_device = torch.from_numpy(masks_np).to(self.device)
42
+
43
+ input_feed = {"input_ids": ids_device, "attention_mask": masks_device}
44
  # Specify the input shapes (batch_size, max_sequence_length)
45
  input_shapes = {"input_ids": ids.shape, "attention_mask": masks.shape}
46
 
 
62
  tmp["question"] = question
63
  valid_dataset = SiameseDatasetStage1(tmp, tokenizer, self.max_length, is_test=True)
64
  valid_loader = DataLoader(valid_dataset, batch_size=self.batch_size, collate_fn=collate_fn,
65
+ num_workers=cpu_count(), shuffle=False, pin_memory=True)
66
  preds = []
67
  with torch.no_grad():
68
  bar = enumerate(valid_loader)