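# NOTE(review): this file appears to have been copied from a hosted "Spaces" page (status banner: "Running"); the banner text is not part of the module.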
import argparse
import dataclasses
import json
import re
from pathlib import Path
from typing import Dict, Optional, Union

import torch
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from huggingface_hub.utils import HfHubHTTPError
from torch import nn
from torch.nn.utils.rnn import pad_sequence

from modules.base import InstructBase
from modules.evaluator import Evaluator, greedy_search
from modules.layers import LstmSeq2SeqEncoder
from modules.span_rep import SpanRepLayer
from modules.token_rep import TokenRepLayer
class GLiNER(InstructBase, PyTorchModelHubMixin):
    """Span-based entity recogniser that scores spans against prompted entity types.

    Entity types are prepended to the input tokens as a prompt of the form
    ``<<ENT>> type_1 <<ENT>> type_2 ... <<ENT>> type_m <<SEP>>``.  The
    representation found at each ``<<ENT>>`` position is used as the
    representation of the entity type that follows it, and is compared (dot
    product) with every candidate span representation.
    """

    def __init__(self, config):
        """Build the sub-modules from ``config``.

        Args:
            config: Object exposing ``model_name``, ``fine_tune``,
                ``subtoken_pooling``, ``hidden_size``, ``span_mode``,
                ``max_width`` and ``dropout`` as attributes.
        """
        super().__init__(config)
        self.config = config

        # Special tokens used to build the entity-type prompt.
        self.entity_token = "<<ENT>>"
        self.sep_token = "<<SEP>>"

        # Token encoder (usually a pretrained bidirectional transformer); the
        # two prompt tokens are registered as additional tokens.
        self.token_rep_layer = TokenRepLayer(
            model_name=config.model_name,
            fine_tune=config.fine_tune,
            subtoken_pooling=config.subtoken_pooling,
            hidden_size=config.hidden_size,
            add_tokens=[self.entity_token, self.sep_token],
        )

        # Bidirectional LSTM over the word representations; each direction
        # gets half of the hidden size.
        self.rnn = LstmSeq2SeqEncoder(
            input_size=config.hidden_size,
            hidden_size=config.hidden_size // 2,
            num_layers=1,
            bidirectional=True,
        )

        # Span representation.
        self.span_rep_layer = SpanRepLayer(
            span_mode=config.span_mode,
            hidden_size=config.hidden_size,
            max_width=config.max_width,
            dropout=config.dropout,
        )

        # Feed-forward projection applied to the entity-type representations.
        self.prompt_rep_layer = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size * 4),
            nn.Dropout(config.dropout),
            nn.ReLU(),
            nn.Linear(config.hidden_size * 4, config.hidden_size),
        )

    def _build_entity_prompt(self, entity_types):
        """Return ``[ENT, type_1, ENT, type_2, ..., ENT, type_m, SEP]`` as a list."""
        prompt = []
        for entity_type in entity_types:
            prompt.extend((self.entity_token, entity_type))
        prompt.append(self.sep_token)
        return prompt

    def compute_score_train(self, x):
        """Score every span against every entity type, one prompt per example.

        Args:
            x: Batch dict.  Reads ``span_idx``, ``span_mask``, ``seq_length``,
                ``tokens`` (one token list per example) and ``classes_to_id``
                (one mapping per example whose keys are the entity types).

        Returns:
            A tuple ``(scores, num_classes, entity_type_mask)`` where
            ``scores`` is the ``'BLKD,BCD->BLKC'`` product of span and
            entity-type representations, ``num_classes`` is the padded number
            of entity types and ``entity_type_mask`` is a
            ``[batch_size, max_num_classes]`` boolean tensor that is True for
            the real (non-padded) entity types of each example.
        """
        device = x['span_mask'].device
        # Zero out the indices of masked (invalid) spans.
        span_idx = x['span_idx'] * x['span_mask'].unsqueeze(-1)

        new_length = x['seq_length'].clone()
        new_tokens = []
        all_len_prompt = []
        num_classes_all = []

        # Prepend each example's own entity-type prompt to its tokens.
        for i, tokens in enumerate(x['tokens']):
            entity_types = list(x['classes_to_id'][i].keys())
            entity_prompt = self._build_entity_prompt(entity_types)

            num_classes_all.append(len(entity_types))
            all_len_prompt.append(len(entity_prompt))
            new_tokens.append(entity_prompt + tokens)
            # The encoder sees prompt + tokens, so lengthen accordingly.
            new_length[i] = new_length[i] + len(entity_prompt)

        # True where the class index is a real class of the example, False for
        # the padding introduced by examples with fewer entity types.
        max_num_classes = max(num_classes_all)
        class_positions = torch.arange(max_num_classes, device=device).unsqueeze(0)
        class_counts = torch.tensor(num_classes_all, device=device).unsqueeze(-1)
        entity_type_mask = class_positions < class_counts  # [batch_size, max_num_classes]

        # Encode prompt + tokens together.
        bert_output = self.token_rep_layer(new_tokens, new_length)
        word_rep_w_prompt = bert_output["embeddings"]
        mask_w_prompt = bert_output["mask"]

        word_rep = []  # token representations (after the prompt)
        mask = []  # token mask (after the prompt)
        entity_type_rep = []  # representations at the [ENT] positions
        for i, prompt_len in enumerate(all_len_prompt):
            text_end = prompt_len + x['seq_length'][i]
            word_rep.append(word_rep_w_prompt[i, prompt_len:text_end])
            mask.append(mask_w_prompt[i, prompt_len:text_end])
            # Drop the trailing [SEP], then keep every other position starting
            # from the first one, i.e. the [ENT] tokens.
            entity_type_rep.append(word_rep_w_prompt[i, :prompt_len - 1][0::2])

        word_rep = pad_sequence(word_rep, batch_first=True)  # [batch_size, seq_len, hidden_size]
        mask = pad_sequence(mask, batch_first=True)  # [batch_size, seq_len]
        entity_type_rep = pad_sequence(entity_type_rep, batch_first=True)  # [batch_size, len_types, hidden_size]

        # Span representations from the contextualised words.
        word_rep = self.rnn(word_rep, mask)
        span_rep = self.span_rep_layer(word_rep, span_idx)

        # Final entity-type representations (FFN).
        entity_type_rep = self.prompt_rep_layer(entity_type_rep)  # [batch_size, len_types, hidden_size]
        num_classes = entity_type_rep.shape[1]

        # Dot-product similarity between every span and every entity type.
        scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)
        return scores, num_classes, entity_type_mask

    def forward(self, x):
        """Return the summed, weighted binary cross-entropy loss for a batch.

        ``x["span_label"]`` holds one label per span: ``-1`` marks spans that
        are ignored, ``0`` means "no entity" and ``c >= 1`` selects class
        column ``c - 1``.  The caller's label tensor is left untouched.

        Args:
            x: Batch dict as consumed by :meth:`compute_score_train`, plus
                ``span_label``.

        Returns:
            Scalar tensor: the sum of the per-(span, class) losses.
        """
        scores, num_classes, entity_type_mask = self.compute_score_train(x)
        batch_size = scores.shape[0]

        logits_label = scores.view(-1, num_classes)
        labels = x["span_label"].view(-1)  # (batch_size * num_spans)
        mask_label = labels != -1  # (batch_size * num_spans)
        # Out-of-place on purpose: `labels` is a view of the caller's tensor,
        # and an in-place fill would erase its -1 padding markers.
        labels = labels.masked_fill(~mask_label, 0)

        # One-hot over (num_classes + 1) columns, then drop column 0 so that
        # label 0 ("no entity") becomes an all-zero target row.
        labels_one_hot = torch.zeros(
            labels.size(0), num_classes + 1, dtype=torch.float32, device=scores.device
        )
        labels_one_hot.scatter_(1, labels.unsqueeze(1), 1)
        labels_one_hot = labels_one_hot[:, 1:]  # (batch_size * num_spans, num_classes)

        all_losses = F.binary_cross_entropy_with_logits(
            logits_label, labels_one_hot, reduction='none'
        )

        # Zero the loss of padded entity types using entity_type_mask (B, C).
        masked_loss = all_losses.view(batch_size, -1, num_classes) * entity_type_mask.unsqueeze(1)
        all_losses = masked_loss.view(-1, num_classes)

        mask_label = mask_label.unsqueeze(-1).expand_as(all_losses)

        # Positive targets are weighted 2, negative targets 1.
        weight_c = labels_one_hot + 1

        all_losses = all_losses * mask_label.float() * weight_c
        return all_losses.sum()

    def compute_score_eval(self, x, device):
        """Score every span against a single entity-type set shared by the batch.

        Args:
            x: Batch dict.  Reads ``classes_to_id`` (one dict for the whole
                batch), ``span_idx``, ``span_mask``, ``tokens`` and
                ``seq_length``.
            device: Device the span indices are moved to.

        Returns:
            The ``'BLKD,BCD->BLKC'`` product of span and entity-type
            representations.

        Raises:
            AssertionError: If ``x['classes_to_id']`` is not a dict.
        """
        assert isinstance(x['classes_to_id'], dict), "classes_to_id must be a dict"

        span_idx = (x['span_idx'] * x['span_mask'].unsqueeze(-1)).to(device)

        # The same prompt is prepended to every example of the batch.
        entity_prompt = self._build_entity_prompt(x['classes_to_id'].keys())
        prompt_entity_length = len(entity_prompt)

        tokens_p = [entity_prompt + tokens for tokens in x['tokens']]
        seq_length_p = x['seq_length'] + prompt_entity_length

        out = self.token_rep_layer(tokens_p, seq_length_p)
        word_rep_w_prompt = out["embeddings"]
        mask_w_prompt = out["mask"]

        # Everything after the prompt is the actual text.
        word_rep = word_rep_w_prompt[:, prompt_entity_length:, :]
        mask = mask_w_prompt[:, prompt_entity_length:]

        # Drop the trailing [SEP]; the [ENT] tokens sit at the even positions.
        entity_type_rep = word_rep_w_prompt[:, :prompt_entity_length - 1, :]
        entity_type_rep = entity_type_rep[:, 0::2, :]
        entity_type_rep = self.prompt_rep_layer(entity_type_rep)  # (batch_size, len_types, hidden_size)

        word_rep = self.rnn(word_rep, mask)
        span_rep = self.span_rep_layer(word_rep, span_idx)

        local_scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)
        return local_scores

    @torch.no_grad()
    def predict(self, x, flat_ner=False, threshold=0.5):
        """Decode spans whose sigmoid score exceeds ``threshold``.

        Puts the module in eval mode and runs without gradient tracking.

        Args:
            x: Batch dict as consumed by :meth:`compute_score_eval`, plus
                ``id_to_classes`` (indexed with ``class_index + 1``).
            flat_ner: Forwarded to ``greedy_search``.
            threshold: Minimum sigmoid score for a (span, class) pair to be
                kept as a candidate.

        Returns:
            One entry per example: the output of ``greedy_search`` applied to
            the ``(start, end, label, score)`` candidates of that example,
            where ``end`` is the inclusive index of the last token.
        """
        self.eval()
        local_scores = self.compute_score_eval(x, device=next(self.parameters()).device)

        spans = []
        for i, tokens in enumerate(x["tokens"]):
            local_i = local_scores[i]
            # Index lists (start, width, class) of every score above threshold.
            hits = [idx.tolist() for idx in torch.where(torch.sigmoid(local_i) > threshold)]
            candidates = []
            for s, k, c in zip(*hits):
                # Discard spans that run past the end of the example.
                if s + k < len(tokens):
                    candidates.append((s, s + k, x["id_to_classes"][c + 1], local_i[s, k, c]))
            spans.append(greedy_search(candidates, flat_ner))
        return spans

    def predict_entities(self, text, labels, flat_ner=True, threshold=0.5):
        """Extract entities from raw ``text`` for the given ``labels``.

        The text is split into words (optionally joined by ``-`` or ``_``) and
        single non-space characters; predicted token spans are mapped back to
        character offsets.

        Args:
            text: Input string.
            labels: Entity types, forwarded to ``self.collate_fn``.
            flat_ner: Forwarded to :meth:`predict`.
            threshold: Forwarded to :meth:`predict`.

        Returns:
            A list of dicts with keys ``start``, ``end`` (character offsets,
            end exclusive), ``text`` and ``label``.  Empty when ``text``
            contains no tokens.
        """
        tokens = []
        start_token_idx_to_text_idx = []
        end_token_idx_to_text_idx = []
        for match in re.finditer(r'\w+(?:[-_]\w+)*|\S', text):
            tokens.append(match.group())
            start_token_idx_to_text_idx.append(match.start())
            end_token_idx_to_text_idx.append(match.end())

        # Nothing to tag (empty or whitespace-only text).
        if not tokens:
            return []

        input_x = {"tokenized_text": tokens, "ner": None}
        x = self.collate_fn([input_x], labels)
        output = self.predict(x, flat_ner=flat_ner, threshold=threshold)

        entities = []
        # NOTE(review): assumes greedy_search yields (start, end, label)
        # triples, i.e. that it strips the score — confirm against its source.
        for start_token_idx, end_token_idx, ent_type in output[0]:
            start_text_idx = start_token_idx_to_text_idx[start_token_idx]
            end_text_idx = end_token_idx_to_text_idx[end_token_idx]
            entities.append({
                "start": start_text_idx,
                "end": end_text_idx,
                "text": text[start_text_idx:end_text_idx],
                "label": ent_type,
            })
        return entities

    def evaluate(self, test_data, flat_ner=False, threshold=0.5, batch_size=12, entity_types=None):
        """Run prediction over ``test_data`` and score it with ``Evaluator``.

        Args:
            test_data: Dataset forwarded to ``self.create_dataloader``.
            flat_ner: Forwarded to :meth:`predict`.
            threshold: Forwarded to :meth:`predict`.
            batch_size: Batch size of the evaluation data loader.
            entity_types: Forwarded to ``self.create_dataloader``.

        Returns:
            The ``(out, f1)`` pair returned by ``Evaluator.evaluate``.
        """
        self.eval()
        data_loader = self.create_dataloader(
            test_data, batch_size=batch_size, entity_types=entity_types, shuffle=False
        )
        device = next(self.parameters()).device

        all_preds = []
        all_trues = []
        for x in data_loader:
            # Move every tensor of the batch to the model's device.
            for k, v in x.items():
                if isinstance(v, torch.Tensor):
                    x[k] = v.to(device)
            all_preds.extend(self.predict(x, flat_ner, threshold))
            all_trues.extend(x["entities"])

        evaluator = Evaluator(all_trues, all_preds)
        out, f1 = evaluator.evaluate()
        return out, f1

    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: bool,
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",
        strict: bool = False,
        **model_kwargs,
    ):
        """Load a model from a local directory or from the Hub.

        Two layouts are supported, tried in this order:

        1. Legacy single-file checkpoints ``gliner_base.pt`` /
           ``gliner_multi.pt`` containing ``config`` and ``model_weights``.
        2. ``pytorch_model.bin`` plus ``gliner_config.json``.

        Each file is taken from ``Path(model_id)`` when it exists there and is
        downloaded with ``hf_hub_download`` otherwise.

        Args:
            model_id: Local directory or Hub repository id.
            revision, cache_dir, force_download, proxies, resume_download,
            local_files_only, token: Forwarded to ``hf_hub_download``.
            map_location: Device the weights are loaded to and the model is
                moved to.
            strict: Forwarded to ``load_state_dict``.
            **model_kwargs: Accepted for interface compatibility; unused.

        Returns:
            The loaded model.
        """
        hub_kwargs = dict(
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            token=token,
            local_files_only=local_files_only,
        )

        def resolve(filename):
            """Return the local file if present, else download it from the Hub."""
            local_path = Path(model_id) / filename
            if local_path.exists():
                return local_path
            return hf_hub_download(repo_id=model_id, filename=filename, **hub_kwargs)

        # SECURITY NOTE: torch.load unpickles the file, which can execute
        # arbitrary code; only load checkpoints from sources you trust.

        # 1. Backwards compatibility: "gliner_base.pt" and "gliner_multi.pt"
        #    bundle the config with the weights.
        legacy_backbones = {
            "gliner_base.pt": "microsoft/deberta-v3-base",
            "gliner_multi.pt": "microsoft/mdeberta-v3-base",
        }
        for filename, backbone in legacy_backbones.items():
            try:
                model_file = resolve(filename)
            except HfHubHTTPError:
                # Not available in this repository; try the next layout.
                continue
            dict_load = torch.load(model_file, map_location=torch.device(map_location))
            config = dict_load["config"]
            state_dict = dict_load["model_weights"]
            config.model_name = backbone
            model = cls(config)
            model.load_state_dict(state_dict, strict=strict, assign=True)
            # Required to update flair's internals as well:
            model.to(map_location)
            return model

        # 2. Newer format: "pytorch_model.bin" and "gliner_config.json".
        from train import load_config_as_namespace

        model_file = resolve("pytorch_model.bin")
        config_file = resolve("gliner_config.json")

        config = load_config_as_namespace(config_file)
        model = cls(config)
        state_dict = torch.load(model_file, map_location=torch.device(map_location))
        model.load_state_dict(state_dict, strict=strict, assign=True)
        model.to(map_location)
        return model

    def save_pretrained(
        self,
        save_directory: Union[str, Path],
        *,
        config: Optional[Union[dict, "DataclassInstance"]] = None,
        repo_id: Optional[str] = None,
        push_to_hub: bool = False,
        **push_to_hub_kwargs,
    ) -> Optional[str]:
        """
        Save weights in local directory.

        Args:
            save_directory (`str` or `Path`):
                Path to directory in which the model weights and configuration will be saved.
            config (`dict` or `DataclassInstance`, *optional*):
                Model configuration specified as a key/value dictionary or a dataclass instance.
                Defaults to `self.config`. `argparse.Namespace` and dataclass instances are
                converted to a dictionary before being written.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Huggingface Hub after saving it.
            repo_id (`str`, *optional*):
                ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
                not provided.
            kwargs:
                Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.

        Returns:
            The value returned by `push_to_hub` when `push_to_hub=True`, otherwise `None`.
        """
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        # save model weights/files
        torch.save(self.state_dict(), save_directory / "pytorch_model.bin")

        # save config (falls back to the config the model was built with)
        if config is None:
            config = self.config
        if config is not None:
            if isinstance(config, argparse.Namespace):
                # Copy so later changes to the dict cannot alter the namespace.
                config = dict(vars(config))
            elif dataclasses.is_dataclass(config) and not isinstance(config, type):
                config = dataclasses.asdict(config)
            (save_directory / "gliner_config.json").write_text(json.dumps(config, indent=2))

        # push to the Hub if required
        if push_to_hub:
            kwargs = push_to_hub_kwargs.copy()  # shallow copy to avoid mutating input
            if config is not None:  # kwarg for `push_to_hub`
                kwargs["config"] = config
            if repo_id is None:
                repo_id = save_directory.name  # Defaults to `save_directory` name
            return self.push_to_hub(repo_id=repo_id, **kwargs)
        return None

    def to(self, device):
        """Move the model and keep ``flair.device`` in sync.

        Args:
            device: Forwarded unchanged to ``nn.Module.to``.  When it names a
                device (``str``, ``int`` or ``torch.device``) it is also
                stored, as a ``torch.device``, in ``flair.device``; any other
                argument (e.g. a dtype) leaves ``flair.device`` untouched.

        Returns:
            ``self``.
        """
        super().to(device)
        import flair

        if isinstance(device, torch.device):
            flair.device = device
        elif isinstance(device, (str, int)):
            flair.device = torch.device(device)
        return self