# mhg-parsing / MHGTagger / rnn_annotate.py
# Uploaded by nielklug in commit 6ed21b9 ("init"); file size 6.5 kB.
#!/usr/bin/python3
import sys
import pickle
import torch
from huggingface_hub import hf_hub_download
from .Data import Data
from .RNNTagger import RNNTagger
from .CRFTagger import CRFTagger
###########################################################################
# main function: annotate() and its argument container Args
###########################################################################
class Args:
    """Plain attribute container standing in for the argparse namespace of the
    original command-line program (see the commented-out argparse options).

    Every constructor argument is stored unchanged as an attribute of the same name.
    """

    def __init__(self, path_param, model_id, path_data, crf_beam_size, gpu, min_prob, print_probs) -> None:
        # Model location: local parameter-file prefix and Hub repository id.
        self.path_param, self.model_id = path_param, model_id
        self.path_data = path_data
        # Decoding / runtime options.
        self.crf_beam_size, self.gpu = crf_beam_size, gpu
        self.min_prob, self.print_probs = min_prob, print_probs
def _select_device(gpu):
    """Return the torch device for GPU index *gpu*; a negative index means CPU.

    Falls back to the CPU when CUDA is unavailable, and to GPU 0 when the
    requested index does not exist; both fallbacks are reported on stderr.
    """
    if gpu >= 0:
        if not torch.cuda.is_available():
            print('No gpu available. Using cpu instead.', file=sys.stderr)
            gpu = -1
        else:
            if gpu >= torch.cuda.device_count():
                print('gpu '+str(gpu)+' not available. Using gpu 0 instead.', file=sys.stderr)
                gpu = 0
            torch.cuda.set_device(gpu)
    return torch.device('cuda' if gpu >= 0 else 'cpu')


def _load_model(path_param, model_id, device):
    """Build the tagger and return ``(data, model)`` with the model in eval mode on *device*.

    The symbol tables (``<path_param>.io``) and the pickled hyper-parameters
    (``<path_param>.hyper``) are read locally. The weights file ``tagger.rnn``
    is downloaded from the Hugging Face Hub repository *model_id*.
    """
    data = Data(path_param+'.io')  # read the symbol mapping tables
    # NOTE(review): pickle.load can execute arbitrary code - only load trusted parameter files.
    with open(path_param+'.hyper', 'rb') as file:
        hyper_params = pickle.load(file)
    # A 10-element hyper-parameter tuple identifies a model with a CRF output layer.
    model = CRFTagger(*hyper_params) if len(hyper_params) == 10 \
        else RNNTagger(*hyper_params)
    model_file = hf_hub_download(repo_id=model_id, filename='tagger.rnn')
    # NOTE(review): torch.load unpickles the downloaded file; consider weights_only=True
    # where the installed torch version supports it.
    model.load_state_dict(torch.load(model_file,
                                     map_location=torch.device('cpu')))
    model = model.to(device)
    model.eval()
    return data, model


def annotate(tokens, path_param='MHGTagger/tagger', model_id='nielklug/rnn_tagger', path_data='', crf_beam_size=10, gpu=-1, min_prob=-1.0, print_probs=True):
    """Tag *tokens* with the RNN (or RNN+CRF) tagger and return the results.

    Args:
        tokens: input handed to ``Data.single_sentences``, which yields the word
            list of each sentence to tag (the accepted format is defined by ``Data``).
        path_param: local path prefix of the ``.io`` and ``.hyper`` parameter files.
        model_id: Hugging Face Hub repository that holds the weights file ``tagger.rnn``.
        path_data: not used by this function; kept for compatibility.
        crf_beam_size: not read by this function; kept for compatibility.
        gpu: GPU index to run on; a negative value selects the CPU.
        min_prob: ``-1.0`` (default) returns only the best tag per word. Any other
            value returns, per word, every tag whose probability exceeds
            ``best_prob * min_prob``, sorted by decreasing probability. The best
            tag is always included. Ignored for CRF models.
        print_probs: also return the tag probabilities, rounded to 4 digits.
            Ignored for CRF models.

    Returns:
        ``(words, tags, probs)``, flat lists over all sentences. In the default
        mode each ``tags`` entry is one tag and each ``probs`` entry one float.
        With ``min_prob != -1.0``, each ``tags`` entry is a list of tags and each
        ``probs`` entry is the matching list of probabilities. ``probs`` stays
        empty unless ``print_probs`` is set and the model has no CRF layer.
    """
    # NOTE(review): the model is rebuilt on every call; cache it if annotate() is hot.
    args = Args(path_param, model_id, path_data, crf_beam_size, gpu, min_prob, print_probs)
    device = _select_device(args.gpu)
    data, model = _load_model(args.path_param, args.model_id, device)
    # _load_model builds exactly a CRFTagger or an RNNTagger, so one flag covers both cases.
    is_crf = type(model) is CRFTagger

    if is_crf:
        # -1.0 is the "off" value of min_prob, but it is truthy, so compare explicitly.
        for option, active in (("min_prob", args.min_prob != -1.0),
                               ("print_probs", bool(args.print_probs))):
            if active:
                print(f"Warning: Option --{option} is ignored because the model has a CRF output layer", file=sys.stderr)

    # Accumulate over ALL sentences (previously these were reset per sentence).
    words_all, tagged, probs_all = [], [], []
    with torch.no_grad():
        for words in data.single_sentences(tokens):
            # map words to numbers and create Torch variables
            fwd_charIDs, bwd_charIDs = data.words2charIDvec(words)
            fwd_charIDs = torch.LongTensor(fwd_charIDs).to(device)
            bwd_charIDs = torch.LongTensor(bwd_charIDs).to(device)

            if is_crf:
                # the CRF model returns the decoded tag IDs directly
                tagIDs = model(fwd_charIDs, bwd_charIDs)
                tags = data.IDs2tags(tagIDs)
                for word, tag in zip(words, tags):
                    words_all.append(word)
                    tagged.append(tag)
                continue

            tagscores = model(fwd_charIDs, bwd_charIDs)
            if args.min_prob == -1.0:
                # only keep the tag with the highest score
                tagIDs = tagscores.argmax(-1)
                tags = data.IDs2tags(tagIDs.to("cpu"))
                if args.print_probs:
                    tagprobs = torch.nn.functional.softmax(tagscores, dim=-1)
                    # probabilities of the highest-scoring tags
                    probs = tagprobs[range(len(tagIDs)), tagIDs].to("cpu").tolist()
                    for word, tag, prob in zip(words, tags, probs):
                        words_all.append(word)
                        tagged.append(tag)
                        probs_all.append(round(float(prob), 4))
                else:
                    for word, tag in zip(words, tags):
                        words_all.append(word)
                        tagged.append(tag)
            else:
                # keep every tag whose probability exceeds best_prob * min_prob
                tagprobs = torch.nn.functional.softmax(tagscores, dim=-1)
                best_probs, best_ids = tagprobs.max(-1)
                thresholds = best_probs * args.min_prob
                selected = tagprobs > thresholds.unsqueeze(1)
                # Always keep the best tag, so no word is left without a tag
                # (with min_prob >= 1 nothing strictly exceeds the threshold).
                selected[torch.arange(len(best_ids), device=selected.device), best_ids] = True
                for word, flags, probs in zip(words, selected, tagprobs):
                    # 1-D tensor of selected tag IDs, matching what the argmax path passes
                    IDs = flags.nonzero().squeeze(1)
                    cand_tags = data.IDs2tags(IDs.to("cpu"))
                    cand_probs = probs[IDs].to("cpu").tolist()
                    # sort the tags by decreasing probability
                    ranked = sorted(zip(cand_tags, cand_probs), key=lambda x: -x[1])
                    words_all.append(word)
                    tagged.append([tag for tag, _ in ranked])
                    if args.print_probs:
                        probs_all.append([round(float(p), 4) for _, p in ranked])
    return (words_all, tagged, probs_all)