# -*- coding: utf-8 -*-
"""scratchpad

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/notebooks/empty.ipynb
"""
#!pip install gradio
#!pip install transformers tokenizers
import torch
from torch import nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelWithLMHead
tokenizer = AutoTokenizer.from_pretrained('distilroberta-base')
tokenizer.save_pretrained("tokenizer")
# from https://github.com/digantamisra98/Mish/blob/b5f006660ac0b4c46e2c6958ad0301d7f9c59651/Mish/Torch/mish.py
@torch.jit.script
def mish(input):
    return input * torch.tanh(F.softplus(input))
class Mish(nn.Module):
    def forward(self, input):
        return mish(input)
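# Quick sanity check (illustrative, not from the original script): mish(0) = 0
# because of the leading x factor, and mish(x) -> x for large x since
# tanh(softplus(x)) -> 1.
assert mish(torch.tensor(0.0)).item() == 0.0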
class NewEmoModel(nn.Module):
    def __init__(self, base_model, n_classes=2, base_model_output_size=768, dropout=0.05):
        super().__init__()
        self.base_model = base_model
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(base_model_output_size, base_model_output_size),
            Mish(),
            nn.Dropout(dropout),
            nn.Linear(base_model_output_size, n_classes)
        )
        for layer in self.classifier:
            if isinstance(layer, nn.Linear):
                layer.weight.data.normal_(mean=0.0, std=0.02)
                if layer.bias is not None:
                    layer.bias.data.zero_()
        self.last_classifier = nn.Sequential(
            nn.Dropout(dropout),
            # n_classes: [V, A] -> 2
            # 4 extra stats: v_bar, v_std, a_bar, a_std
            nn.Linear(2 * n_classes + 4, base_model_output_size),
            Mish(),
            nn.Dropout(dropout),
            nn.Linear(base_model_output_size, n_classes)
        )
    def forward_roberta(self, input_, *args):
        X, attention_mask = input_
        hidden_states = self.base_model(X, attention_mask=attention_mask)
        # maybe do some pooling / RNNs... go crazy here!
        # use the <s> representation
        return self.classifier(hidden_states[0][:, 0, :])
    def forward(self, input_):
        # in1 and in2 each hold (X, attention_mask); the four stats are
        # assumed to be (batch, 1) tensors
        in1, in2, V_bar, V_std, A_bar, A_std = input_
        VAsj = self.forward_roberta(in1)   # [V, A] for sentences s_j..s_k
        VAj_1 = self.forward_roberta(in2)  # [V, A] for sentence s_{j+1}
        # concatenate along the feature dimension -> (batch, 2*n_classes + 4)
        return self.last_classifier(torch.concat([VAsj, VAj_1, V_bar, V_std, A_bar, A_std], dim=1))
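# Illustration of the 2 * n_classes + 4 = 8 feature vector fed to
# last_classifier (a minimal sketch; the ordering mirrors the torch.concat
# call in forward(), and the zero tensors are placeholders):
_features = torch.concat([
    torch.zeros(1, 2),  # [V, A] predicted for the context s_j..s_k
    torch.zeros(1, 2),  # [V, A] predicted for the next sentence s_{j+1}
    torch.zeros(1, 4),  # running stats: V_bar, V_std, A_bar, A_std
], dim=1)
assert _features.shape == (1, 8)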
n_classes = 2
model = NewEmoModel(AutoModelWithLMHead.from_pretrained("distilroberta-base").base_model, n_classes)
model.eval()
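# Shape smoke test with the randomly initialized head (illustrative; the
# tokenization mirrors what get_output() does below):
_enc = tokenizer.encode_plus("hello world")
_pair = (torch.tensor(_enc["input_ids"]).unsqueeze(0),
         torch.tensor(_enc["attention_mask"]).unsqueeze(0))
with torch.no_grad():
    assert model.forward_roberta(_pair).shape == (1, n_classes)  # one [V, A] row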
# arr = ["句子1", "句子2", 0.16, 0, 0.5, 0]
def get_output(arr, ln):
    with torch.no_grad():
        # forward pass
        # arr layout: [sent1, sent2, V_avg, V_sum_of_squares, A_avg, A_sum_of_squares]
        # ln: the number of data points seen so far
        stats = torch.tensor([arr[2:]])  # expected shape: (1, 4) (or (batch, 4))
        enc = tokenizer.encode_plus(arr[0])
        a = (torch.tensor(enc["input_ids"]).unsqueeze(0), torch.tensor(enc["attention_mask"]).unsqueeze(0))
        enc = tokenizer.encode_plus(arr[1])
        b = (torch.tensor(enc["input_ids"]).unsqueeze(0), torch.tensor(enc["attention_mask"]).unsqueeze(0))
        out1 = model.forward_roberta(a)  # s_k..s_j (the context)
        out2 = model.forward_roberta(b)  # s_{j+1} (the next sentence)
        ln += out1.shape[0]  # add the batch size
        ratio = out1.shape[0] / ln  # weight for the running-stats update sketched below
        # square the standard deviations into variances, to match how the model was trained
        stats[0, 1] = stats[0, 1] ** 2
        stats[0, 3] = stats[0, 3] ** 2
        in_f = torch.concat([out1, out2, stats], dim=1)
        output = model.last_classifier(in_f)  # shape: (1, 2) (or (bs, 2))
        return output
# # update the running average & standard deviation (actually Σx², the running sum of squares)
# stats[:,0] = stats[:,0] * (1-ratio) + output[:,0] * ratio
# stats[:,2] = stats[:,2] * (1-ratio) + output[:,1] * ratio
# stats[:,1] = stats[:,1]* (1-ratio) + output[:,0] ** 2 * ratio
# stats[:,3] = stats[:,3]* (1-ratio) + output[:,1] ** 2 * ratio
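# A runnable sketch of the update the commented block above describes
# (hypothetical helper, not part of the original app): blend the previous
# running mean / sum of squares with the new [V, A] prediction by `ratio`.
def update_stats(stats, output, ratio):
    stats = stats.clone()
    stats[:, 0] = stats[:, 0] * (1 - ratio) + output[:, 0] * ratio       # V mean
    stats[:, 2] = stats[:, 2] * (1 - ratio) + output[:, 1] * ratio       # A mean
    stats[:, 1] = stats[:, 1] * (1 - ratio) + output[:, 0] ** 2 * ratio  # V sum of squares
    stats[:, 3] = stats[:, 3] * (1 - ratio) + output[:, 1] ** 2 * ratio  # A sum of squares
    return stats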
# map of the available pretrained weights
mp = {0: "0627_1_epoch_all_unfreezed.pt", 1: "0629_on_pseudo_right_1_epoch_all_unfreezed.pt"}

# arr = ["sentence 1", "sentence 2", 0.16, 0, 0.5, 0]
def fn(sent1, sent2, v_avg, v_sqr, a_avg, a_sqr, ln, pretrained_path_idx):
    # load the selected pretrained weights (Gradio "number" inputs arrive as floats)
    pretrained_path = mp[int(pretrained_path_idx)]
    model.load_state_dict(torch.load(pretrained_path, map_location=torch.device('cpu')))
    # run inference
    arr = [sent1, sent2, v_avg, v_sqr, a_avg, a_sqr]
    out = get_output(arr, ln)
    return float(out[0, 0]), float(out[0, 1])
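# Example invocation with the randomly initialized head (illustrative; real
# requests go through fn(), which first loads one of the checkpoints in mp):
_demo = get_output(["sentence 1", "sentence 2", 0.16, 0.0, 0.5, 0.0], ln=1)
print(_demo)  # tensor of shape (1, 2): predicted [V, A]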
weight_description = []
for k, v in mp.items():
    weight_description.append(f"{k}: {v}")
# join into a single string
weight_description = "\n".join(weight_description)

description = f"""Available pretrained weights. Enter the index of the weight to use (defaults to 0).\n
{weight_description}
"""
import gradio as gr

interface = gr.Interface(
    fn=fn,
    inputs=["text", "text", "number", "number", "number", "number", "number", "number"],
    outputs=["number", "number"],
    description=description
)
interface.launch()