# -*- coding: utf-8 -*-
# @Time    : 2022/4/21 5:30 PM
# @Author  : JianingWang
# @File    : span_proto.py
"""
This code is implemented for the paper "SpanProto: A Two-stage Span-based Prototypical Network for Few-shot Named Entity Recognition"
"""

import os
from typing import Optional, Union
from dataclasses import dataclass

import torch
import numpy as np
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss
from transformers import MegatronBertModel, MegatronBertPreTrainedModel
from transformers.file_utils import ModelOutput
from transformers.models.bert import BertPreTrainedModel, BertModel


@dataclass
class TokenProtoOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None


class TokenProto(nn.Module):
    def __init__(self, config):
        """
        word_encoder: sentence encoder, assumed to be assigned externally as a
        callable mapping (word_ids, mask) -> [num_sent, num_tokens, hidden_size].
        You need to set self.cost as your own loss function.
        """
        nn.Module.__init__(self)
        self.config = config
        self.output_dir = "./outputs"
        # self.predict_dir = self.predict_result_path(self.output_dir)
        self.word_encoder = None  # NOTE: never assigned in the original file; must be set before calling forward()
        self.dot = False  # NOTE: assumption; False -> negative squared Euclidean distance (ProtoNet default), True -> dot product
        self.drop = nn.Dropout()
        self.projector = nn.Sequential(  # projector
            nn.Linear(self.config.hidden_size, self.config.hidden_size),
            nn.Sigmoid(),
            # nn.LayerNorm(2)
        )
        self.tag_embeddings = nn.Embedding(2, self.config.hidden_size)  # tag for labeled / unlabeled span set
        # self.tag_mlp = nn.Linear(self.config.hidden_size, self.config.hidden_size)
        self.max_length = 64
        self.margin_distance = 6.0
        self.global_step = 0

    def predict_result_path(self, path=None):
        if path is None:
            # NOTE: self.mode, self.num_class and self.num_example are assumed to be set externally
            predict_dir = os.path.join(
                self.output_dir, "{}-{}-{}".format(self.mode, self.num_class, self.num_example), "predict"
            )
        else:
            predict_dir = os.path.join(path, "predict")
        # if os.path.exists(predict_dir):
        #     os.rmdir(predict_dir)  # remove stale prediction results
        if not os.path.exists(predict_dir):  # create a fresh directory
            os.makedirs(predict_dir)
        return predict_dir

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
        # NOTE: only the config is consumed here; no pretrained weights are loaded
        config = kwargs.pop("config", None)
        model = cls(config=config)
        return model

    def __dist__(self, x, y, dim):
        # Similarity between x and y along `dim`: dot product if self.dot,
        # otherwise negative squared Euclidean distance (larger = closer).
        if self.dot:
            return (x * y).sum(dim)
        else:
            return -(torch.pow(x - y, 2)).sum(dim)

    def __batch_dist__(self, S, Q, q_mask):
        # S: [num_classes, embed_dim], Q: [num_of_sent, num_of_tokens, embed_dim]
        assert Q.size()[:2] == q_mask.size()
        Q = Q[q_mask == 1].view(-1, Q.size(-1))  # [num_of_all_text_tokens, embed_dim]
        return self.__dist__(S.unsqueeze(0), Q.unsqueeze(1), 2)  # [num_of_all_text_tokens, num_classes]
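
    # Shape walkthrough for __batch_dist__ (illustrative numbers, not from the original):
    # with S = [5, 768] (5 class prototypes) and Q = [4, 32, 768] where q_mask selects
    # 100 real tokens, Q becomes [100, 768]; broadcasting S.unsqueeze(0) = [1, 5, 768]
    # against Q.unsqueeze(1) = [100, 1, 768] and reducing dim 2 yields [100, 5] scores.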

    def __get_proto__(self, embedding, tag, mask):
        # Average the embeddings of all support tokens sharing a label to get one prototype per class.
        proto = []
        embedding = embedding[mask == 1].view(-1, embedding.size(-1))
        tag = torch.cat(tag, 0)
        assert tag.size(0) == embedding.size(0)
        for label in range(torch.max(tag) + 1):
            proto.append(torch.mean(embedding[tag == label], 0))
        proto = torch.stack(proto)  # [num_classes, embed_dim]
        return proto, embedding
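
    # Illustration (hypothetical values): for two support sentences whose text_mask keeps
    # 6 tokens with concatenated tags [0, 0, 1, 1, 2, 2], embedding is [6, 768] and the
    # stacked prototypes have shape [3, 768], one mean vector per tag value 0..2.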

    def forward(self, support, query):
        """
        support: inputs of the support set.
        query: inputs of the query set.
        N: number of classes
        K: number of instances per class in the support set
        Q: number of instances in the query set
        support/query = {"index": [], "word": [], "mask": [], "label": [], "sentence_num": [], "text_mask": []}
        (A runnable smoke-test sketch with stubbed inputs is given at the bottom of this file.)
        """
        # Feed the support set and the query set through BERT to obtain per-token representations
        support_emb = self.word_encoder(support["word"], support["mask"])  # [num_sent, number_of_tokens, 768]
        query_emb = self.word_encoder(query["word"], query["mask"])  # [num_sent, number_of_tokens, 768]
        support_emb = self.drop(support_emb)
        query_emb = self.drop(query_emb)

        # Prototypical Networks
        logits = []
        current_support_num = 0
        current_query_num = 0
        assert support_emb.size()[:2] == support["mask"].size()
        assert query_emb.size()[:2] == query["mask"].size()
        for i, sent_support_num in enumerate(support["sentence_num"]):  # iterate over each sampled N-way K-shot episode
            sent_query_num = query["sentence_num"][i]
            # Calculate the prototype for each class.
            # One batch may hold several episodes, so the slice
            # current_support_num:current_support_num+sent_support_num marks which sentences
            # in the input tensors belong to the current N-way K-shot episode.
            support_proto, embedding = self.__get_proto__(
                support_emb[current_support_num:current_support_num + sent_support_num],
                support["label"][current_support_num:current_support_num + sent_support_num],
                support["text_mask"][current_support_num:current_support_num + sent_support_num])
            # calculate the distance of each query token to each prototype
            logits.append(self.__batch_dist__(
                support_proto,
                query_emb[current_query_num:current_query_num + sent_query_num],
                query["text_mask"][current_query_num:current_query_num + sent_query_num]))  # [num_of_query_tokens, class_num]
            current_query_num += sent_query_num
            current_support_num += sent_support_num
        logits = torch.cat(logits, 0)  # score of each query token against every class in the support set
        _, pred = torch.max(logits, 1)  # the prototype class with the highest score is the prediction
        # return logits, pred, embedding
        # No matter whether the outermost return value is a list or a tuple, the innermost
        # element must contain a tensor; otherwise HuggingFace's nested_detach function errors.
        return TokenProtoOutput(
            logits=logits
        )
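

# A minimal smoke-test sketch (not part of the original file). It stubs the missing
# self.word_encoder with random embeddings so the episode logic above can run
# stand-alone; the config, shapes and labels below are illustrative assumptions.
if __name__ == "__main__":
    from types import SimpleNamespace

    hidden_size = 16
    model = TokenProto(config=SimpleNamespace(hidden_size=hidden_size))
    model.eval()  # disable dropout for a deterministic pass
    # Hypothetical stand-in for the sentence encoder assumed by forward().
    model.word_encoder = lambda word, mask: torch.randn(word.size(0), word.size(1), hidden_size)

    num_sent, num_tokens = 2, 8

    def episode():
        return {
            "word": torch.randint(0, 100, (num_sent, num_tokens)),
            "mask": torch.ones(num_sent, num_tokens, dtype=torch.long),
            "text_mask": torch.ones(num_sent, num_tokens, dtype=torch.long),
            "sentence_num": [num_sent],
        }

    support = episode()
    # one tag per unmasked token; classes 0 and 1 both appear, so every prototype is defined
    support["label"] = [torch.tensor([0, 1, 0, 1, 0, 1, 0, 1]) for _ in range(num_sent)]
    query = episode()

    out = model(support, query)
    print(out.logits.shape)  # torch.Size([16, 2]): 16 query tokens scored against 2 prototypes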