#!/usr/bin/python3
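"""Annotation interface for the RNN-Tagger / CRF-Tagger.

Reads the symbol mapping tables and hyper-parameters from the local parameter
files, downloads the trained model weights from the Hugging Face Hub, and tags
the tokens passed to annotate().  Adapted from the command-line annotation
program of the RNN-Tagger.
"""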

import sys
import pickle
import torch
from huggingface_hub import hf_hub_download

from .Data import Data
from .RNNTagger import RNNTagger
from .CRFTagger import CRFTagger


###########################################################################
# main function
###########################################################################

class Args:
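   """Lightweight stand-in for the argparse namespace of the original command-line program."""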
   def __init__(self, path_param, model_id, path_data, crf_beam_size, gpu, min_prob, print_probs) -> None:
      self.path_param = path_param
      self.model_id = model_id
      self.path_data = path_data
      self.crf_beam_size = crf_beam_size
      self.gpu = gpu
      self.min_prob = min_prob
      self.print_probs = print_probs

def annotate(tokens, path_param='MHGTagger/tagger', model_id='nielklug/rnn_tagger', path_data='', crf_beam_size=10, gpu=-1, min_prob=-1.0, print_probs=True):
   """Annotate a sequence of tokens with the RNN-Tagger.

   The keyword arguments correspond to the options of the original
   command-line annotation program:
     path_param    -- name of the parameter file (without extension)
     model_id      -- Hugging Face Hub repository providing the trained model weights
     path_data     -- name of the file with input data (kept for compatibility, not used here)
     crf_beam_size -- size of the CRF beam (if the system contains a CRF layer)
     gpu           -- GPU to use; the default -1 selects the CPU
     min_prob      -- report all tags whose probability exceeds the probability
                      of the best tag times this threshold
     print_probs   -- also report the tag probabilities
   """

   args = Args(path_param, model_id, path_data, crf_beam_size, gpu, min_prob, print_probs)

   # Select the processing device
   if args.gpu >= 0:
      if not torch.cuda.is_available():
         print('No gpu available. Using cpu instead.', file=sys.stderr)
         args.gpu = -1
      else:
         if args.gpu >= torch.cuda.device_count():
            print('gpu '+str(args.gpu)+' not available. Using gpu 0 instead.', file=sys.stderr)
            args.gpu = 0
         torch.cuda.set_device(args.gpu)
   device = torch.device('cuda' if args.gpu >= 0 else 'cpu')

   # load parameters
   data  = Data(args.path_param+'.io')   # read the symbol mapping tables

   with open(args.path_param+'.hyper', 'rb') as file:
      hyper_params = pickle.load(file)
   # a model with 10 hyper-parameters is a CRF tagger, otherwise a plain RNN tagger
   model = CRFTagger(*hyper_params) if len(hyper_params)==10 \
           else RNNTagger(*hyper_params)

   # fetch the trained weights from the Hugging Face Hub (cached locally)
   # and load them onto the CPU before moving the model to the selected device
   model_file = hf_hub_download(repo_id=args.model_id, filename='tagger.rnn')
   model.load_state_dict(torch.load(model_file,
                         map_location=torch.device('cpu')))

   model = model.to(device)

   if type(model) is CRFTagger:
      for optvar, option in zip((args.min_prob, args.print_probs),
                                ("min_prob","print_probs")):
         if optvar:
            print(f"Warning: Option --{option} is ignored because the model has a CRF output layer", file=sys.stderr)
   
   # accumulators for the words, tags and (optional) tag probabilities to be returned
   words_all = []
   tagged = []
   probs_all = []

   model.eval()
   with torch.no_grad():
      for i, words in enumerate(data.single_sentences(tokens)):
         # print(i, end='\r', file=sys.stderr, flush=True)

         # map words to numbers and create Torch variables
         fwd_charIDs, bwd_charIDs = data.words2charIDvec(words)
         fwd_charIDs = torch.LongTensor(fwd_charIDs).to(device)
         bwd_charIDs = torch.LongTensor(bwd_charIDs).to(device)

         # run the model
         if type(model) is RNNTagger:
            tagscores = model(fwd_charIDs, bwd_charIDs)
            if args.min_prob == -1.0:
               # keep only the tag with the highest score for each word
               tagIDs = tagscores.argmax(-1)
               tags = data.IDs2tags(tagIDs.to("cpu"))
               if not args.print_probs:
                  for word, tag in zip(words, tags):
                     # print(word, tag, sep="\t")
                     words_all.append(word)
                     tagged.append(tag)
               else:
                  # collect the tag probabilities as well
                  tagprobs = torch.nn.functional.softmax(tagscores, dim=-1)
                  # get the probabilities of the highest-scoring tags
                  probs = tagprobs[range(len(tagIDs)), tagIDs].to("cpu").tolist()
                  # store the result
                  for word, tag, prob in zip(words, tags, probs):
                     # print(word, tag, round(float(prob), 4), sep="\t")
                     words_all.append(word)
                     tagged.append(tag)
                     probs_all.append(round(float(prob), 4))
            else:
               # print the best tags for each word
               tagprobs = torch.nn.functional.softmax(tagscores, dim=-1)
               # get the most probable tag and its probability
               best_probs, _ = tagprobs.max(-1)
               # get all tags with a probability above best_prob * min_prob
               thresholds = best_probs * args.min_prob
               greaterflags = (tagprobs > thresholds.unsqueeze(1))
               for word, flags, probs in zip(words, greaterflags, tagprobs):
                  # get the IDs of the best tags
                  IDs = flags.nonzero()
                  # get the best tags and their probabilities
                  best_probs = probs[IDs].to("cpu")
                  best_tags = data.IDs2tags(IDs.to("cpu"))
                  # sort the tags by decreasing probability
                  sorted_list = sorted(zip(best_tags, best_probs), key=lambda x:-x[1])
                  best_tags, best_probs = zip(*sorted_list)
                  # generate the output
                  if args.print_probs:
                     # append the probabilities to the tags
                     best_tags = [f"{t} {float(p):.4f}" for t, p in zip(best_tags, best_probs)]
                  print(word, ' '.join(best_tags), sep="\t")
         elif type(model) is CRFTagger:
            tagIDs = model(fwd_charIDs, bwd_charIDs)
            tags = data.IDs2tags(tagIDs)
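            # NB: the CRF branch still prints its output, as in the original
            # command-line program; it does not add to the returned lists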
            for word, tag in zip(words, tags):
               print(word, tag, sep='\t')
         else:
            sys.exit('Error: unknown model type')

   # return the words, tags and probabilities accumulated over all sentences
   return (words_all, tagged, probs_all)
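

# Example usage (illustrative sketch; the package path and the assumption that
# `tokens` is a flat list of word strings are not guaranteed by this file):
#
#     from MHGTagger.annotate import annotate
#     words, tags, probs = annotate(['Uns', 'ist', 'in', 'alten', 'maeren'])
#     for word, tag, prob in zip(words, tags, probs):
#         print(word, tag, prob, sep='\t')
#
# Because of the relative imports, this module must be imported as part of its
# package (or run with `python -m`) rather than executed as a standalone script.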