murphy / app.py
cheesexuebao's picture
Adding Blob
ae15986
raw
history blame
No virus
3.05 kB
import gradio as gr
import torch
from transformers import BertModel
from transformers import BertTokenizer
import torch.nn as nn
class BertSST2Model(nn.Module):
# 初始化类
def __init__(self, class_size, pretrained_name='bert-base-chinese'):
"""
Args:
class_size :指定分类模型的最终类别数目,以确定线性分类器的映射维度
pretrained_name :用以指定bert的预训练模型
"""
# 类继承的初始化,固定写法
super(BertSST2Model, self).__init__()
# 加载HuggingFace的BertModel
# BertModel的最终输出维度默认为768
# return_dict=True 可以使BertModel的输出具有dict属性,即以 bert_output['last_hidden_state'] 方式调用
self.bert = BertModel.from_pretrained(pretrained_name,
return_dict=True)
# 通过一个线性层将[CLS]标签对应的维度:768->class_size
# class_size 在SST-2情感分类任务中设置为:2
self.classifier = nn.Linear(768, class_size)
def forward(self, inputs):
# 获取DataLoader中已经处理好的输入数据:
# input_ids :tensor类型,shape=batch_size*max_len max_len为当前batch中的最大句长
# input_tyi :tensor类型,
# input_attn_mask :tensor类型,因为input_ids中存在大量[Pad]填充,attention mask将pad部分值置为0,让模型只关注非pad部分
input_ids, input_tyi, input_attn_mask = inputs['input_ids'], inputs[
'token_type_ids'], inputs['attention_mask']
# 将三者输入进模型,如果想知道模型内部如何运作,前面的蛆以后再来探索吧~
output = self.bert(input_ids, input_tyi, input_attn_mask)
# bert_output 分为两个部分:
# last_hidden_state:最后一个隐层的值
# pooler output:对应的是[CLS]的输出,用于分类任务
# 通过线性层将维度:768->2
# categories_numberic:tensor类型,shape=batch_size*class_size,用于后续的CrossEntropy计算
categories_numberic = self.classifier(output.pooler_output)
return categories_numberic
device = torch.device("cpu")
pretrained_model_name = './bert-base-uncased'
# 创建模型 BertSST2Model
model = BertSST2Model(2, pretrained_model_name)
# 固定写法,将模型加载到device上,
# 如果是GPU上运行,此时可以观察到GPU的显存增加
model.to(device)
# 加载预训练模型对应的tokenizer
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name)
def modelscope_quickstart(sentence):
inputs = tokenizer(sentence,
padding=True,
truncation=True,
return_tensors="pt",
max_length=512)
output = model(inputs)
cate = output.argmax(dim=1)
return "分类结果为:" + str(0 if cate else 1)
demo = gr.Interface(fn=modelscope_quickstart, inputs="text", outputs="text")
demo.launch()