tomiwa1a committed
Commit 010ad9e
Parent: 230907f

use device_number for setting GPU


Avoids the error: "Expected a torch.device with a specified index or an integer, but got:cuda"
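Background, as a minimal sketch rather than part of the commit: PyTorch raises that ValueError (verbatim, including the missing space after the colon) from its internal device-index helpers, for example via torch.cuda.set_device, whenever it needs a concrete GPU ordinal and receives a bare "cuda" device with no index. Passing an integer index, or torch.device('cuda', 0), satisfies it. The names device and device_number below mirror the commit; the Linear model and the CPU guard are illustrative, and note that -1 is only a transformers-pipeline convention for CPU, not a valid .to() target.

    import torch

    # Integer GPU ordinal: 0 selects the first GPU; -1 is the
    # transformers-pipeline convention for CPU, not a valid .to() target.
    device_number = 0 if torch.cuda.is_available() else -1
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = torch.nn.Linear(4, 4)
    # .to(0) is shorthand for .to(torch.device('cuda', 0)); guard the CPU case.
    model = model.to(device_number) if device_number >= 0 else model.to('cpu')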

Files changed (1): handler.py (+26, -19)
handler.py CHANGED
@@ -18,15 +18,18 @@ class EndpointHandler():
     SENTENCE_TRANSFORMER_MODEL_NAME = "multi-qa-mpnet-base-dot-v1"
     QUESTION_ANSWER_MODEL_NAME = "vblagoje/bart_lfqa"
     SUMMARIZER_MODEL_NAME = "philschmid/bart-large-cnn-samsum"
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    device_number = 0 if torch.cuda.is_available() else -1
 
     def __init__(self, path=""):
 
         device = "cuda" if torch.cuda.is_available() else "cpu"
         device_number = 0 if torch.cuda.is_available() else -1
         print(f'whisper and question_answer_model will use: {device}')
+        print(f'whisper and question_answer_model will use device_number: {device_number}')
 
         t0 = time.time()
-        self.whisper_model = whisper.load_model(self.WHISPER_MODEL_NAME).to(device)
+        self.whisper_model = whisper.load_model(self.WHISPER_MODEL_NAME).to(device_number)
         t1 = time.time()
 
         total = t1 - t0
@@ -45,10 +48,11 @@ class EndpointHandler():
 
         total = t1 - t0
         print(f'Finished loading summarizer in {total} seconds')
-
+
         self.question_answer_tokenizer = AutoTokenizer.from_pretrained(self.QUESTION_ANSWER_MODEL_NAME)
         t0 = time.time()
-        self.question_answer_model = AutoModelForSeq2SeqLM.from_pretrained(self.QUESTION_ANSWER_MODEL_NAME).to(device)
+        self.question_answer_model = AutoModelForSeq2SeqLM.from_pretrained \
+            (self.QUESTION_ANSWER_MODEL_NAME).to(device_number)
         t1 = time.time()
         total = t1 - t0
         print(f'Finished loading question_answer_model in {total} seconds')
@@ -199,22 +203,25 @@ class EndpointHandler():
         conditioned_doc = "<P> " + " <P> ".join([d for d in documents])
         query_and_docs = "question: {} context: {}".format(query, conditioned_doc)
 
-        model_input = self.question_answer_tokenizer(query_and_docs, truncation=False, padding=True, return_tensors="pt")
-
-        generated_answers_encoded = self.question_answer_model.generate(input_ids=model_input["input_ids"].to(self.device),
-                                                                        attention_mask=model_input["attention_mask"].to(self.device),
-                                                                        min_length=64,
-                                                                        max_length=256,
-                                                                        do_sample=False,
-                                                                        early_stopping=True,
-                                                                        num_beams=8,
-                                                                        temperature=1.0,
-                                                                        top_k=None,
-                                                                        top_p=None,
-                                                                        eos_token_id=self.question_answer_tokenizer.eos_token_id,
-                                                                        no_repeat_ngram_size=3,
-                                                                        num_return_sequences=1)
-        answer = self.question_answer_tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True,clean_up_tokenization_spaces=True)
+        model_input = self.question_answer_tokenizer(query_and_docs, truncation=False, padding=True,
+                                                     return_tensors="pt")
+
+        generated_answers_encoded = self.question_answer_model.generate(
+            input_ids=model_input["input_ids"].to(self.device),
+            attention_mask=model_input["attention_mask"].to(self.device),
+            min_length=64,
+            max_length=256,
+            do_sample=False,
+            early_stopping=True,
+            num_beams=8,
+            temperature=1.0,
+            top_k=None,
+            top_p=None,
+            eos_token_id=self.question_answer_tokenizer.eos_token_id,
+            no_repeat_ngram_size=3,
+            num_return_sequences=1)
+        answer = self.question_answer_tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True,
+                                                             clean_up_tokenization_spaces=True)
         return answer
 
     @staticmethod
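One subtlety worth flagging, as an observation rather than part of the diff: the generate() hunk reads self.device, and the only definition this commit adds is the class-level device attribute in the first hunk; the local variable device inside __init__ never creates self.device. The lookup still succeeds because Python falls back from instance attributes to class attributes. A toy sketch of that lookup (Handler and which_device are illustrative names, not from handler.py):

    import torch

    class Handler:
        # Class attribute, shared by all instances and reachable as self.device.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        def __init__(self):
            # A local name like this vanishes when __init__ returns;
            # it does not define self.device.
            device = "cuda" if torch.cuda.is_available() else "cpu"

        def which_device(self):
            # No instance attribute 'device' was ever assigned, so Python
            # resolves self.device against the class attribute above.
            return self.device

    print(Handler().which_device())  # device(type='cuda') or device(type='cpu')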