AmitGarage committed on
Commit
a25ba4b
1 Parent(s): 945abe6

Update scripts/torch_ner_pipe.py

Browse files
Files changed (1) hide show
  1. scripts/torch_ner_pipe.py +2 -24
scripts/torch_ner_pipe.py CHANGED
@@ -34,7 +34,6 @@ def set_torch_dropout_rate(model: Model, dropout_rate: float):
34
  model (Model): Thinc Model (with PyTorch sub-modules)
35
  dropout_rate (float): Dropout rate
36
  """
37
- #print("Entered set_torch_dropout_rate - ")
38
  set_dropout_rate(model, dropout_rate)
39
  func = model.get_ref("torch_model").attrs["set_dropout_rate"]
40
  func(dropout_rate)
@@ -78,7 +77,6 @@ def make_torch_entity_recognizer(nlp: Language, name: str, model: Model):
78
  in size, and be normalized as probabilities (all scores between 0 and 1,
79
  with the rows summing to 1).
80
  """
81
- #print("Entered make_torch_entity_recognizer - ")
82
  return TorchEntityRecognizer(nlp.vocab, model, name)
83
 
84
 
@@ -92,23 +90,17 @@ class TorchEntityRecognizer(TrainablePipe):
92
  name (str): The component instance name, used to add entries to the
93
  losses during training.
94
  """
95
- #print("Entered pipe TorchEntityRecognizer.__init__ - ")
96
  self.vocab = vocab
97
  self.model = model
98
  self.name = name
99
  cfg = {"labels": []}
100
  self.cfg = dict(sorted(cfg.items()))
101
- #print(self.vocab,self.model,self.name,self.cfg)
102
- #print(self.model.layers[0].ref_names)
103
- #print(self.model.layers[1].ref_names)
104
- #print("Completed pipe TorchEntityRecognizer.__init__ - ")
105
 
106
  @property
107
  def labels(self) -> Tuple[str, ...]:
108
  """The labels currently added to the component.
109
  RETURNS (Tuple[str]): The labels.
110
  """
111
- ##print("Entered TorchEntityRecognizer.labels - ")
112
  labels = ["O"]
113
  for label in self.cfg["labels"]:
114
  for iob in ["B", "I"]:
@@ -120,7 +112,6 @@ class TorchEntityRecognizer(TrainablePipe):
120
  docs (Iterable[Doc]): The documents to predict.
121
  RETURNS: The models prediction for each document.
122
  """
123
- #print("Entered pipe TorchEntityRecognizer.predict - ")
124
  if not any(len(doc) for doc in docs):
125
  # Handle cases where there are no tokens in any docs.
126
  n_labels = len(self.labels)
@@ -144,7 +135,6 @@ class TorchEntityRecognizer(TrainablePipe):
144
  docs (Iterable[Doc]): The documents to modify.
145
  preds (Iterable[Ints1d]): The IDs to set, produced by TorchEntityRecognizer.predict.
146
  """
147
- #print("Entered pipe TorchEntityRecognizer.set_annotations - ")
148
  if isinstance(docs, Doc):
149
  docs = [docs]
150
  for doc, tag_ids in zip(docs, preds):
@@ -176,7 +166,6 @@ class TorchEntityRecognizer(TrainablePipe):
176
  Updated using the component name as the key.
177
  RETURNS (Dict[str, float]): The updated losses dictionary.
178
  """
179
- #print("Entered pipe TorchEntityRecognizer.update - ")
180
  if losses is None:
181
  losses = {}
182
  losses.setdefault(self.name, 0.0)
@@ -208,7 +197,6 @@ class TorchEntityRecognizer(TrainablePipe):
208
  scores: Scores representing the model's predictions.
209
  RETURNS (Tuple[float, float]): The loss and the gradient.
210
  """
211
- #print("Entered pipe TorchEntityRecognizer.get_loss - ")
212
  validate_examples(examples, "TorchEntityRecognizer.get_loss")
213
  loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False)
214
  truths = []
@@ -238,7 +226,6 @@ class TorchEntityRecognizer(TrainablePipe):
238
  `init labels` command. If no labels are provided, the get_examples
239
  callback is used to extract the labels from the data.
240
  """
241
- #print("Entered pipe TorchEntityRecognizer.initialize - ")
242
  validate_get_examples(get_examples, "TorchEntityRecognizer.initialize")
243
  if labels is not None:
244
  for tag in labels:
@@ -257,24 +244,16 @@ class TorchEntityRecognizer(TrainablePipe):
257
 
258
  self._require_labels()
259
  assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
260
- #print(nlp.config["components"][self.name]["model"]["nO"])
261
- ##print(nlp.config["components"][self.name]["model"]["nI"])
262
  self.model.initialize(X=doc_sample, Y=self.labels)
263
- #print("self.model.initialize exit")
264
- #print(self.model.name)
265
- #print(self.model.layers[0].ref_names)
266
- #print(self.model.layers[1].ref_names)
267
- #print(self.name)
268
  nlp.config["components"][self.name]["model"]["nO"] = len(self.labels)
269
- #nlp.config["components"][self.name]["model"]["nI"] = 768
270
- #print(nlp.config["components"][self.name]["model"])
271
 
272
  def add_label(self, label: str) -> int:
273
  """Add a new label to the pipe.
274
  label (str): The label to add.
275
  RETURNS (int): 0 if label is already present, otherwise 1.
276
  """
277
- #print("Entered pipe TorchEntityRecognizer.add_label - ")
278
  if not isinstance(label, str):
279
  raise ValueError(Errors.E187)
280
  if label in self.labels:
@@ -289,6 +268,5 @@ class TorchEntityRecognizer(TrainablePipe):
289
  examples (Iterable[Example]): The examples to score.
290
  RETURNS (Dict[str, Any]): The NER precision, recall and f-scores.
291
  """
292
- #print("Entered pipe TorchEntityRecognizer.score - ")
293
  validate_examples(examples, "TorchEntityRecognizer.score")
294
  return get_ner_prf(examples)
 
34
  model (Model): Thinc Model (with PyTorch sub-modules)
35
  dropout_rate (float): Dropout rate
36
  """
 
37
  set_dropout_rate(model, dropout_rate)
38
  func = model.get_ref("torch_model").attrs["set_dropout_rate"]
39
  func(dropout_rate)
 
77
  in size, and be normalized as probabilities (all scores between 0 and 1,
78
  with the rows summing to 1).
79
  """
 
80
  return TorchEntityRecognizer(nlp.vocab, model, name)
81
 
82
 
 
90
  name (str): The component instance name, used to add entries to the
91
  losses during training.
92
  """
 
93
  self.vocab = vocab
94
  self.model = model
95
  self.name = name
96
  cfg = {"labels": []}
97
  self.cfg = dict(sorted(cfg.items()))
 
 
 
 
98
 
99
  @property
100
  def labels(self) -> Tuple[str, ...]:
101
  """The labels currently added to the component.
102
  RETURNS (Tuple[str]): The labels.
103
  """
 
104
  labels = ["O"]
105
  for label in self.cfg["labels"]:
106
  for iob in ["B", "I"]:
 
112
  docs (Iterable[Doc]): The documents to predict.
113
  RETURNS: The models prediction for each document.
114
  """
 
115
  if not any(len(doc) for doc in docs):
116
  # Handle cases where there are no tokens in any docs.
117
  n_labels = len(self.labels)
 
135
  docs (Iterable[Doc]): The documents to modify.
136
  preds (Iterable[Ints1d]): The IDs to set, produced by TorchEntityRecognizer.predict.
137
  """
 
138
  if isinstance(docs, Doc):
139
  docs = [docs]
140
  for doc, tag_ids in zip(docs, preds):
 
166
  Updated using the component name as the key.
167
  RETURNS (Dict[str, float]): The updated losses dictionary.
168
  """
 
169
  if losses is None:
170
  losses = {}
171
  losses.setdefault(self.name, 0.0)
 
197
  scores: Scores representing the model's predictions.
198
  RETURNS (Tuple[float, float]): The loss and the gradient.
199
  """
 
200
  validate_examples(examples, "TorchEntityRecognizer.get_loss")
201
  loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False)
202
  truths = []
 
226
  `init labels` command. If no labels are provided, the get_examples
227
  callback is used to extract the labels from the data.
228
  """
 
229
  validate_get_examples(get_examples, "TorchEntityRecognizer.initialize")
230
  if labels is not None:
231
  for tag in labels:
 
244
 
245
  self._require_labels()
246
  assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
 
 
247
  self.model.initialize(X=doc_sample, Y=self.labels)
 
 
 
 
 
248
  nlp.config["components"][self.name]["model"]["nO"] = len(self.labels)
249
+ if self.model.layers[0].maybe_get_ref("listener") != None :
250
+ nlp.config["components"][self.name]["model"]["width"] = self.model.layers[0].maybe_get_ref("listener").maybe_get_dim("nO")
251
 
252
  def add_label(self, label: str) -> int:
253
  """Add a new label to the pipe.
254
  label (str): The label to add.
255
  RETURNS (int): 0 if label is already present, otherwise 1.
256
  """
 
257
  if not isinstance(label, str):
258
  raise ValueError(Errors.E187)
259
  if label in self.labels:
 
268
  examples (Iterable[Example]): The examples to score.
269
  RETURNS (Dict[str, Any]): The NER precision, recall and f-scores.
270
  """
 
271
  validate_examples(examples, "TorchEntityRecognizer.score")
272
  return get_ner_prf(examples)