Lee1112061 committed on
Commit
8ce1196
1 Parent(s): 8a11f35

Create predict.py

Files changed (1)
  predict.py +84 -0
predict.py ADDED
@@ -0,0 +1,84 @@
+ import shutil
+ import requests
+ import pandas as pd
+ import numpy as np
+ import seaborn as sns
+ import matplotlib.pyplot as plt
+ import torch.nn.functional as F
+ import torch
+ import streamlit as st
+ from torch import nn
+ from torch.optim import Adam
+ from tqdm import tqdm
+ from bs4 import BeautifulSoup
+ from sklearn.model_selection import train_test_split
+ from collections import defaultdict
+ from transformers import RobertaTokenizerFast, RobertaForQuestionAnswering
+
+ # Load model parameters from a saved checkpoint
+ def load_ckp(checkpoint_fpath, model, optimizer):
+     """
+     checkpoint_fpath: path of the saved checkpoint
+     model: model that we want to load checkpoint parameters into
+     optimizer: optimizer we defined in previous training
+     """
+     # Load the checkpoint
+     # map_location=torch.device('cpu') is required when loading on a CPU-only machine
+     checkpoint = torch.load(checkpoint_fpath, map_location=torch.device('cpu'))
+     # Initialize model weights from the checkpoint
+     model.load_state_dict(checkpoint['state_dict'])
+     # Initialize optimizer state from the checkpoint
+     optimizer.load_state_dict(checkpoint['optimizer'])
+     # Restore the minimum validation loss reached during training
+     valid_loss_min = checkpoint['valid_loss_min']
+     # Return model, optimizer, epoch value, and minimum validation loss
+     return model, optimizer, checkpoint['epoch'], valid_loss_min
+
+ # Use the GPU when available
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+ # Pretrained tokenizer and model
+ tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
+ model = RobertaForQuestionAnswering.from_pretrained("roberta-base")
+ model.to(device)
+
+ # Optimizer
+ LEARNING_RATE = 3e-5
+ optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
+
+ # Load the best checkpoint parameters
+ load_ckp('./model/best_model.pt', model, optimizer)
+
+
+ # Prediction function
+ def Predict(question, context):
+     inputs = tokenizer(question, context, max_length=512, padding='max_length', return_offsets_mapping=True, return_tensors="pt")
+     input_ids = inputs['input_ids'].to(device)
+     attention_mask = inputs['attention_mask'].to(device)
+     outputs = model(input_ids.reshape(1, 512), attention_mask.reshape(1, 512))
+     answer_start_index = outputs.start_logits.argmax(dim=1)
+     answer_end_index = outputs.end_logits.argmax(dim=1)
+
+     # Decode the predicted answer span (the end index is inclusive)
+     predict_answer = tokenizer.decode(inputs['input_ids'].flatten()[answer_start_index.item() : answer_end_index.item() + 1])
+     return predict_answer.strip()
+
+
+ # qa = load_ckp('./model/best_model.pt', model, optimizer)
+ def main():
+     st.title("Question Answering")
+
+     with st.form("text_field"):
+         question = st.text_input("Enter a question:")
+         context = st.text_area("Enter some context:")
+
+         # clicked is True only when the button is clicked
+         clicked = st.form_submit_button("Submit")
+         if clicked:
+             results = Predict(question=question, context=context)
+             st.json(results)
+
+
+ if __name__ == "__main__":
+     main()
+
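
For reference, load_ckp above expects ./model/best_model.pt to be a dict containing the keys state_dict, optimizer, epoch, and valid_loss_min. The training code is not part of this commit, so the following is only a minimal sketch of how such a checkpoint could have been written; the epoch and loss values are placeholders.

import os
import torch
from torch.optim import Adam
from transformers import RobertaForQuestionAnswering

# Illustrative training-side snippet (assumption; not from this repository)
model = RobertaForQuestionAnswering.from_pretrained("roberta-base")
optimizer = Adam(model.parameters(), lr=3e-5)

checkpoint = {
    'epoch': 3,                           # placeholder: last completed epoch
    'valid_loss_min': 0.42,               # placeholder: best validation loss so far
    'state_dict': model.state_dict(),     # model weights read back by load_ckp
    'optimizer': optimizer.state_dict(),  # optimizer state read back by load_ckp
}
os.makedirs('./model', exist_ok=True)
torch.save(checkpoint, './model/best_model.pt')

Once that checkpoint exists, the app can be started with streamlit run predict.py; the submitted question and context are passed to Predict and the decoded answer is rendered with st.json.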