File size: 810 Bytes
30cd0bc
 
 
 
 
ee8ab93
30cd0bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
from transformers import MLukeTokenizer
from torch import nn 

# Sentiment-classification script: tokenize one line of user input with a
# Japanese LUKE tokenizer, run it through a fine-tuned model, and print the
# positive-class probability plus a positive/negative label.
tokenizer = MLukeTokenizer.from_pretrained('studio-ousia/luke-japanese-base-lite')
# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine.
# NOTE(review): torch.load of a whole-model pickle requires the model class to
# be importable at load time — confirm the checkpoint format.
model = torch.load('C:\\[modelのあるディレクトリ]\\My_luke_model_pn.pth',
                   map_location=torch.device('cpu'))
model.eval()  # disable dropout etc. so inference is deterministic

text = input()

encoded_dict = tokenizer.encode_plus(
                        text,
                        return_attention_mask = True,   # build the attention mask
                        return_tensors = 'pt',     # return PyTorch tensors
                )

# Inference only — no_grad skips building the autograd graph.
with torch.no_grad():
    pre = model(encoded_dict['input_ids'], token_type_ids=None, attention_mask=encoded_dict['attention_mask'])

SOFTMAX = nn.Softmax(dim=0)
# pre.logits[0] is the 1-D logit vector for the single input; softmax over it
# yields class probabilities. Index 1 is assumed to be the positive class
# — TODO confirm against the fine-tuning label order.
num = SOFTMAX(pre.logits[0])
if num[1] > 0.5:
    print(str(num[1]))
    print('ポジティブ')
else:
    print(str(num[1]))
    print('ネガティブ')