# -*- coding: utf-8 -*-
# @Time : 2022/4/21 5:30 PM
# @Author : JianingWang
# @File : span_proto.py
"""
This code is implemented for the paper "SpanProto: A Two-stage Span-based Prototypical Network for Few-shot Named Entity Recognition".
"""
import os
from typing import Optional, Union
import torch
import numpy as np
import torch.nn as nn
from dataclasses import dataclass
from torch.nn import BCEWithLogitsLoss
from transformers import MegatronBertModel, MegatronBertPreTrainedModel
from transformers.file_utils import ModelOutput
from transformers.models.bert import BertPreTrainedModel, BertModel
@dataclass
class TokenProtoOutput(ModelOutput):
loss: Optional[torch.FloatTensor] = None
logits: Optional[torch.FloatTensor] = None
class TokenProto(nn.Module):
def __init__(self, config):
"""
word_encoder: Sentence encoder
You need to set self.cost as your own loss function.
"""
nn.Module.__init__(self)
self.config = config
self.output_dir = "./outputs"
# self.predict_dir = self.predict_result_path(self.output_dir)
self.drop = nn.Dropout()
self.projector = nn.Sequential( # projector
nn.Linear(self.config.hidden_size, self.config.hidden_size),
nn.Sigmoid(),
# nn.LayerNorm(2)
)
self.tag_embeddings = nn.Embedding(2, self.config.hidden_size) # tag for labeled / unlabeled span set
# self.tag_mlp = nn.Linear(self.config.hidden_size, self.config.hidden_size)
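        # NOTE: the projector and tag_embeddings above are not used in this module's forward pass;
        # they appear to be reserved for the span/tag components of SpanProto.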
self.max_length = 64
self.margin_distance = 6.0
self.global_step = 0
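        # the sentence encoder is expected to be assigned externally before forward() is called
        # (e.g. a BERT-style model returning [num_sent, num_tokens, hidden_size])
        self.word_encoder = None
        # similarity used in __dist__: dot product if True, otherwise the negative squared
        # Euclidean distance (default assumed here, since the original file never sets it)
        self.dot = False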
def predict_result_path(self, path=None):
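        # builds (and creates, if necessary) the directory that prediction files are written to.
        # NOTE: the path=None branch relies on self.mode, self.num_class and self.num_example,
        # which are not set in __init__ and are expected to be assigned externally.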
if path is None:
predict_dir = os.path.join(
self.output_dir, "{}-{}-{}".format(self.mode, self.num_class, self.num_example), "predict"
)
else:
predict_dir = os.path.join(
path, "predict"
)
        # if os.path.exists(predict_dir):
        #     os.rmdir(predict_dir)  # remove the previous prediction history
        if not os.path.exists(predict_dir):  # create the directory if it does not exist yet
os.makedirs(predict_dir)
return predict_dir
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
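        # NOTE: unlike a typical from_pretrained, this ignores pretrained_model_name_or_path and
        # any checkpoint weights; it only builds a fresh TokenProto from the config passed in kwargs.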
config = kwargs.pop("config", None)
model = TokenProto(config=config)
return model
def __dist__(self, x, y, dim):
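        # similarity between x and y along `dim`: dot product when self.dot is True,
        # otherwise the negative squared Euclidean distance used by prototypical networks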
if self.dot:
return (x * y).sum(dim)
else:
return -(torch.pow(x - y, 2)).sum(dim)
def __batch_dist__(self, S, Q, q_mask):
# S [class, embed_dim], Q [num_of_sent, num_of_tokens, embed_dim]
assert Q.size()[:2] == q_mask.size()
Q = Q[q_mask==1].view(-1, Q.size(-1)) # [num_of_all_text_tokens, embed_dim]
return self.__dist__(S.unsqueeze(0), Q.unsqueeze(1), 2)
def __get_proto__(self, embedding, tag, mask):
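        # one prototype per class: the mean of the token embeddings tagged with that class
        # across the support sentences of the current episode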
proto = []
embedding = embedding[mask==1].view(-1, embedding.size(-1))
tag = torch.cat(tag, 0)
assert tag.size(0) == embedding.size(0)
for label in range(torch.max(tag)+1):
proto.append(torch.mean(embedding[tag==label], 0))
proto = torch.stack(proto)
return proto, embedding
def forward(self, support, query):
"""
support: Inputs of the support set.
query: Inputs of the query set.
N: Num of classes
K: Num of instances for each class in the support set
Q: Num of instances in the query set
support/query = {"index": [], "word": [], "mask": [], "label": [], "sentence_num": [], "text_mask": []}
"""
        # feed the support set and the query set into BERT separately to obtain token representations for every sample
support_emb = self.word_encoder(support["word"], support["mask"]) # [num_sent, number_of_tokens, 768]
query_emb = self.word_encoder(query["word"], query["mask"]) # [num_sent, number_of_tokens, 768]
support_emb = self.drop(support_emb)
query_emb = self.drop(query_emb)
# Prototypical Networks
logits = []
current_support_num = 0
current_query_num = 0
assert support_emb.size()[:2] == support["mask"].size()
assert query_emb.size()[:2] == query["mask"].size()
        for i, sent_support_num in enumerate(support["sentence_num"]):  # iterate over each sampled N-way K-shot episode
sent_query_num = query["sentence_num"][i]
# Calculate prototype for each class
            # one batch can contain several episodes, so the slice
            # current_support_num:current_support_num+sent_support_num marks which sentences in the
            # input tensors belong to the current N-way K-shot episode
support_proto, embedding = self.__get_proto__(
support_emb[current_support_num:current_support_num+sent_support_num],
support["label"][current_support_num:current_support_num+sent_support_num],
support["text_mask"][current_support_num: current_support_num+sent_support_num])
# calculate distance to each prototype
logits.append(self.__batch_dist__(
support_proto,
query_emb[current_query_num:current_query_num+sent_query_num],
query["text_mask"][current_query_num: current_query_num+sent_query_num])) # [num_of_query_tokens, class_num]
current_query_num += sent_query_num
current_support_num += sent_support_num
        logits = torch.cat(logits, 0)  # for each query token, its score of belonging to each class in the support set
        _, pred = torch.max(logits, 1)  # take the prototype class with the highest score as the prediction
# return logits, pred, embedding
return TokenProtoOutput(
logits=logits
        )  # however the returned logits are wrapped (list or tuple), the innermost element must be a tensor, otherwise HuggingFace's nested_detach function will raise an error
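

if __name__ == "__main__":
    # Minimal smoke test (not part of the original SpanProto code): it builds a TokenProto with a
    # toy config and a dummy word encoder, then runs a single 2-way episode on random data.
    # DummyConfig and DummyEncoder are illustrative assumptions, not names from the paper or repo.
    class DummyConfig:
        hidden_size = 32

    class DummyEncoder(nn.Module):
        # stands in for a BERT-style encoder: returns [num_sent, num_tokens, hidden_size]
        def __init__(self, hidden_size, vocab_size=100):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, hidden_size)

        def forward(self, word, mask):
            return self.embed(word)

    model = TokenProto(config=DummyConfig())
    model.word_encoder = DummyEncoder(DummyConfig.hidden_size)

    num_sent, num_tokens = 2, 8
    # alternate 0/1 labels so that both classes are present in the support set
    labels = [torch.arange(num_tokens) % 2 for _ in range(num_sent)]
    support = {
        "word": torch.randint(0, 100, (num_sent, num_tokens)),
        "mask": torch.ones(num_sent, num_tokens, dtype=torch.long),
        "label": labels,
        "sentence_num": [num_sent],
        "text_mask": torch.ones(num_sent, num_tokens, dtype=torch.long),
    }
    query = {
        "word": torch.randint(0, 100, (num_sent, num_tokens)),
        "mask": torch.ones(num_sent, num_tokens, dtype=torch.long),
        "sentence_num": [num_sent],
        "text_mask": torch.ones(num_sent, num_tokens, dtype=torch.long),
    }
    output = model(support, query)
    print(output.logits.shape)  # expected: torch.Size([16, 2]) -> [num_query_tokens, num_classes]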