murphy / assets /Prediction.py.bak
cheesexuebao's picture
Modify tables
74b913c
raw
history blame
No virus
3.82 kB
### Install the required packages
# !pip install transformers
# !pip install torchmetrics
# !pip3 install ogb pytorch_lightning -q
import pandas as pd
from tqdm.auto import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizerFast as BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup
# import pytorch_lightning as pl
# Widen pandas' console display so wide frames print up to 500 columns.
pd.options.display.max_columns = 500
# Fixed seed constant; not referenced by the visible code of this file.
RANDOM_SEED = 42
class ModelTagger(nn.Module):
    """BERT encoder with a linear multi-label head that outputs per-label probabilities.

    Args:
        model_path: Name or path passed to ``BertModel.from_pretrained``.
        n_classes: Number of output labels (independent sigmoid units).
            Defaults to 4, the value previously hard-coded.
    """

    def __init__(self, model_path="bert-base-uncased", n_classes=4):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_path, return_dict=True)
        self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes)
        # The loss is computed on raw logits: BCEWithLogitsLoss fuses the
        # sigmoid and the log, which is numerically more stable than applying
        # sigmoid first and then BCELoss.
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, input_ids, attention_mask, labels=None):
        """Score a batch of token ids.

        Args:
            input_ids: Token ids passed to the BERT encoder.
            attention_mask: Attention mask passed to the BERT encoder.
            labels: Optional float multi-hot targets; when given, a loss is computed.

        Returns:
            ``(loss, probs)``: ``loss`` is ``0`` when ``labels`` is None,
            otherwise the binary cross-entropy; ``probs`` is the sigmoid of
            the classifier logits computed from BERT's pooled output.
        """
        encoded = self.bert(input_ids, attention_mask=attention_mask)
        logits = self.classifier(encoded.pooler_output)
        probs = torch.sigmoid(logits)
        loss = 0
        if labels is not None:
            loss = self.criterion(logits, labels)
        return loss, probs
class Predict_Dataset(Dataset):
    """Map-style dataset that tokenizes one text cell per item for inference.

    Each item is a dict holding the raw text under ``"post"`` plus 1-D
    ``"input_ids"`` and ``"attention_mask"`` tensors, padded/truncated to
    ``max_token_len`` tokens.
    """

    def __init__(
        self,
        data: pd.DataFrame,
        text_col: str,
        tokenizer: BertTokenizer,
        max_token_len: int = 128
    ):
        self.data = data
        self.text_col = text_col
        self.tokenizer = tokenizer
        self.max_token_len = max_token_len

    def __len__(self):
        # One item per DataFrame row.
        return self.data.shape[0]

    def __getitem__(self, index: int):
        text = self.data.iloc[index][self.text_col]
        tokens = self.tokenizer.encode_plus(
            text,
            max_length=self.max_token_len,
            padding="max_length",
            truncation=True,
            add_special_tokens=True,
            return_attention_mask=True,
            return_token_type_ids=False,
            return_tensors='pt',
        )
        # flatten() collapses the returned tensors to 1-D (presumably dropping
        # a leading batch axis of size 1) so the DataLoader can stack items.
        item = {"post": text}
        item["input_ids"] = tokens["input_ids"].flatten()
        item["attention_mask"] = tokens["attention_mask"].flatten()
        return item
def predict(data, text_col, tokenizer, model, device, LABEL_COLUMNS,
            max_token_len=128, batch_size=1000):
    """Run ``model`` over ``data[text_col]`` and collect per-label scores.

    Args:
        data: DataFrame holding the texts.
        text_col: Name of the column containing the text to score.
        tokenizer: Tokenizer handed to ``Predict_Dataset``.
        model: Called as ``model(input_ids, attention_mask)``; the second
            element of the returned pair is used as the predictions. The
            caller is expected to have moved it to ``device`` already
            (this function does not).
        device: Device the input tensors are moved to.
        LABEL_COLUMNS: Label names, in the same order as the model's output columns.
        max_token_len: Token length each text is padded/truncated to.
        batch_size: Number of rows per forward pass.

    Returns:
        dict mapping label name -> list of floats, one per row of ``data``.
        Label names are paired with output columns in order; the shorter of
        the two determines how many labels are returned. For empty ``data``
        every label maps to an empty list.
    """
    dataset = Predict_Dataset(data, text_col, tokenizer, max_token_len=max_token_len)
    loader = DataLoader(dataset, batch_size=batch_size, num_workers=0)

    batches = []
    # Inference only: with autograd disabled no activations are kept for
    # backprop, which otherwise dominates memory for large BERT batches.
    with torch.no_grad():
        for batch in tqdm(loader):
            _, probs = model(
                batch["input_ids"].to(device),
                batch["attention_mask"].to(device),
            )
            batches.append(probs.detach().cpu())

    if not batches:
        # torch.cat raises on an empty list; no rows means empty columns.
        return {label: [] for label in LABEL_COLUMNS}

    # Rows are samples; transpose so each row is one output column.
    per_label = torch.cat(batches, dim=0).numpy().T
    return {label: column.tolist() for label, column in zip(LABEL_COLUMNS, per_label)}
def get_result(df, result, LABEL_COLUMNS):
    """Write per-label predictions into ``df`` as columns.

    Args:
        df: DataFrame to annotate; it is modified in place.
        result: Mapping from label name to a sequence with one value per
            row of ``df`` (e.g. the dict returned by ``predict``).
        LABEL_COLUMNS: Label names to copy from ``result`` into ``df``;
            any number of labels is supported.

    Returns:
        The same ``df`` object, returned for convenience.

    Raises:
        KeyError: if a name in ``LABEL_COLUMNS`` is missing from ``result``.
    """
    for label in LABEL_COLUMNS:
        df[label] = result[label]
    return df
# ---- Inference script: score the first 20 sentences with the Kickstarter checkpoint ----
import os

Data = pd.read_csv("Kickstarter_sentence_level_5000.csv")
# .copy() makes an independent frame: get_result() adds columns to it, and
# adding columns to a plain slice triggers pandas' SettingWithCopyWarning.
Data = Data.iloc[:20].copy()

device = torch.device('cpu')
BERT_MODEL_NAME = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
LABEL_COLUMNS = ["Assertive Tone", "Conversational Tone", "Emotional Tone", "Informative Tone"]

# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# Newer PyTorch releases default to weights_only=True, which may reject a
# checkpoint that stores non-tensor objects (presumably a Lightning checkpoint,
# given the 'state_dict' key) — confirm against the installed version.
params = torch.load("checkpoints/Kickstarter.ckpt", map_location='cpu')['state_dict']
kick_model = ModelTagger()
kick_model.load_state_dict(params, strict=True)
kick_model.eval()  # evaluation mode (e.g. dropout off) for inference
kick_model = kick_model.to(device)

kick_fk_doc_result = predict(Data, "content", tokenizer, kick_model, device, LABEL_COLUMNS)
fk_result = get_result(Data, kick_fk_doc_result, LABEL_COLUMNS)
# to_csv does not create missing directories.
os.makedirs("output", exist_ok=True)
fk_result.to_csv("output/prediction_origin_Kickstarter.csv")
# tab_output = gr.Label(label='Probability Predictions:', value=dict(zip(LABEL_COLUMNS, [0]*len(LABEL_COLUMNS))))