socialcomp commited on
Commit
b26a83f
1 Parent(s): d74d8bc

Update predict.py

Browse files

Add the process checking whether GPU or CPU can be used

Files changed (1) hide show
  1. predict.py +3 -3
predict.py CHANGED
@@ -19,7 +19,7 @@ from bs4 import BeautifulSoup
19
  # import pandas as pd
20
  # import numpy as np
21
  # import codecs
22
-
23
 
24
  #%% global変数として使う
25
  dict_key = {}
@@ -60,7 +60,7 @@ def predict_entities(modelpath, sentences_list, len_num_entity_type):
60
  # bert_tc = model.bert_tc.cuda()
61
 
62
  model = ner.BertForTokenClassification_pl(modelpath, num_labels=81, lr=1e-5)
63
- bert_tc = model.bert_tc.cuda()
64
 
65
  MODEL_NAME = 'cl-tohoku/bert-base-japanese-whole-word-masking'
66
  tokenizer = ner.NER_tokenizer_BIO.from_pretrained(
@@ -79,7 +79,7 @@ def predict_entities(modelpath, sentences_list, len_num_entity_type):
79
  encoding, spans = tokenizer.encode_plus_untagged(
80
  text, return_tensors='pt'
81
  )
82
- encoding = { k: v.cuda() for k, v in encoding.items() }
83
 
84
  with torch.no_grad():
85
  output = bert_tc(**encoding)
 
19
  # import pandas as pd
20
  # import numpy as np
21
  # import codecs
22
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
23
 
24
  #%% global変数として使う
25
  dict_key = {}
 
60
  # bert_tc = model.bert_tc.cuda()
61
 
62
  model = ner.BertForTokenClassification_pl(modelpath, num_labels=81, lr=1e-5)
63
+ bert_tc = model.bert_tc.to(device)
64
 
65
  MODEL_NAME = 'cl-tohoku/bert-base-japanese-whole-word-masking'
66
  tokenizer = ner.NER_tokenizer_BIO.from_pretrained(
 
79
  encoding, spans = tokenizer.encode_plus_untagged(
80
  text, return_tensors='pt'
81
  )
82
+ encoding = { k: v.to(device) for k, v in encoding.items() }
83
 
84
  with torch.no_grad():
85
  output = bert_tc(**encoding)