| 'Model training for NLP' |
| from ..torch_core import * |
| from ..basic_train import * |
| from ..callbacks import * |
| from ..data_block import CategoryList |
| from ..basic_data import * |
| from ..datasets import * |
| from ..metrics import accuracy |
| from ..train import GradientClipping |
| from ..layers import * |
| from .models import * |
| from .transform import * |
| from .data import * |
|
|
__all__ = [
    'RNNLearner', 'LanguageLearner', 'convert_weights', 'decode_spec_tokens', 'get_language_model',
    'language_model_learner', 'MultiBatchEncoder', 'get_text_classifier', 'text_classifier_learner',
    'PoolingLinearClassifier',
]


# Per-architecture metadata read by the model and learner factories in this module:
#   hid_name               -- config key whose value is the hidden size fed to the decoder/classifier head
#   url / url_bwd          -- pretrained weights for forward / backward text (absent when none exist)
#   config_lm, split_lm    -- default language-model config and the function splitting it into layer groups
#   config_clas, split_clas -- the same for the text classifier
_model_meta = {
    AWD_LSTM: {
        'hid_name': 'emb_sz',
        'url': URLs.WT103_FWD,
        'url_bwd': URLs.WT103_BWD,
        'config_lm': awd_lstm_lm_config,
        'split_lm': awd_lstm_lm_split,
        'config_clas': awd_lstm_clas_config,
        'split_clas': awd_lstm_clas_split,
    },
    Transformer: {
        'hid_name': 'd_model',
        'url': URLs.OPENAI_TRANSFORMER,
        'config_lm': tfmer_lm_config,
        'split_lm': tfmer_lm_split,
        'config_clas': tfmer_clas_config,
        'split_clas': tfmer_clas_split,
    },
    TransformerXL: {
        'hid_name': 'd_model',
        'config_lm': tfmerXL_lm_config,
        'split_lm': tfmerXL_lm_split,
        'config_clas': tfmerXL_clas_config,
        'split_clas': tfmerXL_clas_split,
    },
}
|
|
def convert_weights(wgts:Weights, stoi_wgts:Dict[str,int], itos_new:Collection[str]) -> Weights:
    """Convert the model `wgts` to go with a new vocabulary.

    `stoi_wgts` maps each token of the vocabulary `wgts` was trained with to its row index, and `itos_new`
    is the new vocabulary. Rows of tokens known to `wgts` are copied over. Tokens it does not know get the
    mean embedding row (and the mean decoder bias). The new encoder matrix is also copied into
    '1.decoder.weight' and, when present, '0.encoder_dp.emb.weight'. `wgts` is updated in place and returned.
    """
    dec_bias, enc_wgts = wgts.get('1.decoder.bias', None), wgts['0.encoder.weight']
    # Row of each new token in the old matrices, or -1 when the token was not in the old vocabulary.
    old_idx = torch.tensor([stoi_wgts.get(w, -1) for w in itos_new], dtype=torch.long, device=enc_wgts.device)
    unknown = old_idx < 0
    # A single gather replaces a per-token Python loop. Unknown tokens gather row 0 first and are then
    # overwritten with the mean. Advanced indexing returns a copy, so the old tensors are untouched.
    src_idx = old_idx.clamp(min=0)
    new_w = enc_wgts[src_idx]
    new_w[unknown] = enc_wgts.mean(0)
    wgts['0.encoder.weight'] = new_w
    if '0.encoder_dp.emb.weight' in wgts: wgts['0.encoder_dp.emb.weight'] = new_w.clone()
    wgts['1.decoder.weight'] = new_w.clone()
    if dec_bias is not None:
        new_b = dec_bias[src_idx]
        new_b[unknown] = dec_bias.mean(0)
        wgts['1.decoder.bias'] = new_b
    return wgts
|
|
class RNNLearner(Learner):
    "Basic class for a `Learner` in NLP."
    def __init__(self, data:DataBunch, model:nn.Module, split_func:OptSplitFunc=None, clip:float=None,
                 alpha:float=2., beta:float=1., metrics=None, **learn_kwargs):
        """Wrap `data` and `model` in a `Learner` with an `RNNTrainer` callback.

        `clip` (if truthy) adds gradient clipping. `split_func` (if given) splits the model into layer
        groups. `metrics` defaults to accuracy when the targets are a `CategoryList` or `LMLabelList`."""
        is_class = hasattr(data.train_ds, 'y') and isinstance(data.train_ds.y, (CategoryList, LMLabelList))
        metrics = ifnone(metrics, ([accuracy] if is_class else []))
        super().__init__(data, model, metrics=metrics, **learn_kwargs)
        # `alpha`/`beta` are forwarded to `RNNTrainer` (presumably the AR/TAR regularization weights).
        self.callbacks.append(RNNTrainer(self, alpha=alpha, beta=beta))
        if clip: self.callback_fns.append(partial(GradientClipping, clip=clip))
        if split_func: self.split(split_func)

    def save_encoder(self, name:str):
        "Save the encoder to `name` inside the model directory."
        if is_pathlike(name): self._test_writeable_path()
        encoder = get_model(self.model)[0]
        # Unwrap a wrapper exposing `.module` (e.g. `MultiBatchEncoder`) so the bare encoder's keys are saved.
        if hasattr(encoder, 'module'): encoder = encoder.module
        torch.save(encoder.state_dict(), self.path/self.model_dir/f'{name}.pth')

    def load_encoder(self, name:str, device:torch.device=None):
        "Load the encoder `name` from the model directory onto `device` (default: the data's device), then freeze."
        encoder = get_model(self.model)[0]
        if device is None: device = self.data.device
        # Same unwrapping as `save_encoder`, so the saved keys line up.
        if hasattr(encoder, 'module'): encoder = encoder.module
        encoder.load_state_dict(torch.load(self.path/self.model_dir/f'{name}.pth', map_location=device))
        self.freeze()

    def load_pretrained(self, wgts_fname:str, itos_fname:str, strict:bool=True):
        "Load a pretrained model and adapt it to the data vocabulary."
        # NOTE: `pickle.load` can execute arbitrary code; only load vocabulary files from a trusted source.
        with open(itos_fname, 'rb') as f: old_itos = pickle.load(f)
        old_stoi = {v:k for k,v in enumerate(old_itos)}
        # The identity `map_location` keeps the loaded tensors where they were deserialized (CPU).
        wgts = torch.load(wgts_fname, map_location=lambda storage, loc: storage)
        # Full checkpoints may nest the weights under a 'model' key.
        if 'model' in wgts: wgts = wgts['model']
        wgts = convert_weights(wgts, old_stoi, self.data.train_ds.vocab.itos)
        self.model.load_state_dict(wgts, strict=strict)

    def get_preds(self, ds_type:DatasetType=DatasetType.Valid, activ:nn.Module=None, with_loss:bool=False, n_batch:Optional[int]=None,
                  pbar:Optional[PBar]=None, ordered:bool=False) -> List[Tensor]:
        """Return predictions and targets on the valid, train, or test set, depending on `ds_type`.

        With `ordered=True`, results are put back in dataset order by inverting the dataloader's sampler."""
        self.model.reset()
        # Seed before predicting and again before re-iterating the sampler, so that a randomized sampler
        # yields the same order both times (assumes it draws from NumPy's global RNG) and can be inverted.
        if ordered: np.random.seed(42)
        preds = super().get_preds(ds_type=ds_type, activ=activ, with_loss=with_loss, n_batch=n_batch, pbar=pbar)
        if ordered and hasattr(self.dl(ds_type), 'sampler'):
            np.random.seed(42)
            sampler = list(self.dl(ds_type).sampler)
            reverse_sampler = np.argsort(sampler)
            preds = [p[reverse_sampler] for p in preds]
        return preds
|
|
def decode_spec_tokens(tokens):
    """Undo the special-token rules in the list of string `tokens` and return the decoded token list.

    `TK_MAJ t` gives `t` capitalized and `TK_UP t` gives `t` upper-cased. `TK_REP n t` gives `t` repeated `n`
    times as one string. `TK_WREP n t` gives `n` copies of `t`. Every other token is kept as is.
    """
    new_toks,rule,arg = [],None,None
    for t in tokens:
        # A new rule token always starts from a clean state (no stale repeat count).
        if t in (TK_MAJ, TK_UP, TK_REP, TK_WREP): rule,arg = t,None
        elif rule is None: new_toks.append(t)
        elif rule == TK_MAJ:
            new_toks.append(t[:1].upper() + t[1:].lower())
            rule = None
        elif rule == TK_UP:
            new_toks.append(t.upper())
            rule = None
        elif arg is None:
            # TK_REP / TK_WREP are followed by the repeat count, then by the token to repeat.
            try: arg = int(t)
            except ValueError:
                # Malformed sequence: drop the rule but keep the token instead of silently losing it.
                new_toks.append(t)
                rule = None
        else:
            if rule == TK_REP: new_toks.append(t * arg)
            else: new_toks += [t] * arg
            # The repetition is complete; later tokens must not be repeated too.
            rule,arg = None,None
    return new_toks
|
|
class LanguageLearner(RNNLearner):
    "Subclass of RNNLearner for predictions."

    def predict(self, text:str, n_words:int=1, no_unk:bool=True, temperature:float=1., min_p:float=None, sep:str=' ',
                decoder=decode_spec_tokens):
        """Return `text` followed by `n_words` tokens sampled one at a time from the model's predictions.

        `no_unk` forbids the unknown token. `min_p` zeroes probabilities below it (unless that would zero
        them all). `temperature` sharpens (<1) or flattens (>1) the distribution."""
        self.model.reset()
        xb,yb = self.data.one_item(text)
        new_idx = []
        for _ in range(n_words):
            # Probabilities for the token following the last position of `xb`.
            res = self.pred_batch(batch=(xb,yb))[0][-1]
            if no_unk: res[self.data.vocab.stoi[UNK]] = 0.
            if min_p is not None:
                if (res >= min_p).float().sum() == 0:
                    warn(f"There is no item with probability >= {min_p}, try a lower value.")
                else: res[res < min_p] = 0.
            if temperature != 1.: res.pow_(1 / temperature)
            idx = torch.multinomial(res, 1).item()
            new_idx.append(idx)
            # Only the new token is fed next: the model was reset once above and its recurrent state carries the context.
            xb = xb.new_tensor([idx])[None]
        return text + sep + sep.join(decoder(self.data.vocab.textify(new_idx, sep=None)))

    def beam_search(self, text:str, n_words:int, no_unk:bool=True, top_k:int=10, beam_sz:int=1000, temperature:float=1.,
                    sep:str=' ', decoder=decode_spec_tokens):
        """Return the `n_words` that come after `text` using beam search.

        Each step expands every beam with its `top_k` next tokens and keeps the `beam_sz` best candidates.
        A final candidate is then sampled with weights `exp(-score / temperature)`."""
        self.model.reset()
        self.model.eval()
        xb, yb = self.data.one_item(text)
        nodes = xb.clone()
        # Length of the numericalized prompt; only the tokens generated after it are decoded at the end
        # (`text` itself is prepended verbatim).
        n_prompt = nodes.size(1)
        # Cumulative negative log-likelihood of each beam: lower is better.
        scores = xb.new_zeros(1).float()
        with torch.no_grad():
            for _ in progress_bar(range(n_words), leave=False):
                out = F.log_softmax(self.model(xb)[0][:,-1], dim=-1)
                if no_unk: out[:,self.data.vocab.stoi[UNK]] = -float('Inf')
                values, indices = out.topk(top_k, dim=-1)
                scores = (-values + scores[:,None]).view(-1)
                # Beam each candidate was expanded from. It is created on the same device as `sort_idx`,
                # which indexes it below (a CPU tensor cannot be indexed by a CUDA index).
                indices_idx = torch.arange(nodes.size(0), device=nodes.device)[:,None].expand(nodes.size(0), top_k).contiguous().view(-1)
                sort_idx = scores.argsort()[:beam_sz]
                scores = scores[sort_idx]
                nodes = torch.cat([nodes[:,None].expand(nodes.size(0),top_k,nodes.size(1)),
                                   indices[:,:,None].expand(nodes.size(0),top_k,1),], dim=2)
                nodes = nodes.view(-1, nodes.size(2))[sort_idx]
                # Presumably re-indexes the encoder's hidden state so it matches the surviving beams.
                self.model[0].select_hidden(indices_idx[sort_idx])
                xb = nodes[:,-1][:,None]
        if temperature != 1.: scores.div_(temperature)
        # Shift by the best score before exponentiating. The sampling distribution is unchanged, but long
        # generations no longer underflow to all-zero weights, which makes `multinomial` fail.
        node_idx = torch.multinomial(torch.exp(scores.min() - scores), 1).item()
        return text + sep + sep.join(decoder(self.data.vocab.textify([i.item() for i in nodes[node_idx][n_prompt:]], sep=None)))

    def show_results(self, ds_type=DatasetType.Valid, rows:int=5, max_len:int=20):
        "Show `rows` result of predictions on `ds_type` dataset."
        from IPython.display import display, HTML
        ds = self.dl(ds_type).dataset
        x,y = self.data.one_batch(ds_type, detach=False, denorm=False)
        preds = self.pred_batch(batch=(x,y))
        y = y.view(*x.size())
        # Most likely token at every position.
        z = preds.view(*x.size(),-1).argmax(dim=2)
        xs = [ds.x.reconstruct(grab_idx(x, i)) for i in range(rows)]
        ys = [ds.x.reconstruct(grab_idx(y, i)) for i in range(rows)]
        zs = [ds.x.reconstruct(grab_idx(z, i)) for i in range(rows)]
        items,names = [],['text', 'target', 'pred']
        for xt,yt,zt in zip(xs,ys,zs):
            # First `max_len` input tokens, then the next `max_len` targets and predictions. This assumes the
            # targets are the inputs shifted by one, so that index `max_len-1` follows the shown text.
            txt_x = ' '.join(xt.text.split(' ')[:max_len])
            txt_y = ' '.join(yt.text.split(' ')[max_len-1:2*max_len-1])
            txt_z = ' '.join(zt.text.split(' ')[max_len-1:2*max_len-1])
            items.append([txt_x, txt_y, txt_z])
        items = np.array(items)
        df = pd.DataFrame({n:items[:,i] for i,n in enumerate(names)}, columns=names)
        # NOTE(review): newer pandas versions reject -1 here (None means "no limit"); kept for older pandas.
        with pd.option_context('display.max_colwidth', -1):
            display(HTML(df.to_html(index=False)))
|
|
def get_language_model(arch:Callable, vocab_sz:int, config:dict=None, drop_mult:float=1.):
    """Build an `arch` encoder plus a `LinearDecoder` over `vocab_sz` tokens, wrapped in a `SequentialRNN`.

    `config` defaults to the architecture's language-model config and is copied, never mutated. Every
    `*_p` dropout entry is multiplied by `drop_mult`. No pretrained weights are loaded here."""
    meta = _model_meta[arch]
    cfg = ifnone(config, meta['config_lm']).copy()
    dropout_keys = [key for key in cfg if key.endswith('_p')]
    for key in dropout_keys:
        cfg[key] *= drop_mult
    # These entries configure the decoder/initialization and must not reach the encoder's constructor.
    tie_weights = cfg.pop('tie_weights')
    output_p = cfg.pop('output_p')
    out_bias = cfg.pop('out_bias')
    init = cfg.pop('init', None)
    encoder = arch(vocab_sz, **cfg)
    # With tied weights, the decoder reuses the encoder's embedding matrix.
    tied = encoder.encoder if tie_weights else None
    n_hid = cfg[meta['hid_name']]
    decoder = LinearDecoder(vocab_sz, n_hid, output_p, tie_encoder=tied, bias=out_bias)
    model = SequentialRNN(encoder, decoder)
    if init is None:
        return model
    return model.apply(init)
|
|
def language_model_learner(data:DataBunch, arch, config:dict=None, drop_mult:float=1., pretrained:bool=True,
                           pretrained_fnames:OptStrTuple=None, **learn_kwargs) -> 'LanguageLearner':
    """Create a `Learner` with a language model from `data` and `arch`.

    Weights come from `pretrained_fnames` (a `(weights, vocab)` pair of names in the model directory, without
    extensions) when given. Otherwise, when `pretrained`, they are downloaded for the text direction of
    `data`. The learner is frozen after loading."""
    model = get_language_model(arch, len(data.vocab.itos), config=config, drop_mult=drop_mult)
    meta = _model_meta[arch]
    learn = LanguageLearner(data, model, split_func=meta['split_lm'], **learn_kwargs)
    url_key = 'url_bwd' if data.backwards else 'url'
    if not (pretrained or pretrained_fnames):
        return learn
    if pretrained_fnames is not None:
        fnames = [learn.path/learn.model_dir/f'{fname}.{ext}' for fname,ext in zip(pretrained_fnames, ['pth', 'pkl'])]
    elif url_key in meta:
        model_path = untar_data(meta[url_key], data=False)
        fnames = [list(model_path.glob(f'*.{ext}'))[0] for ext in ['pth', 'pkl']]
    else:
        warn("There are no pretrained weights for that architecture yet!")
        return learn
    learn.load_pretrained(*fnames)
    learn.freeze()
    return learn
|
|
def masked_concat_pool(outputs, mask):
    """Pool the last layer of MultiBatchEncoder `outputs` into one vector per sample.

    `mask` is True at padding positions. The result concatenates, along dim 1, the last layer at the final
    position, its max over non-padded positions, and its mean over non-padded positions."""
    last_layer = outputs[-1]
    pad = mask[:, :, None]
    seq_len = last_layer.size(1)
    # Mean over all positions with padding zeroed, rescaled into a mean over the real tokens only.
    mean_all = last_layer.masked_fill(pad, 0).mean(dim=1)
    n_real = seq_len - mask.type(mean_all.dtype).sum(dim=1)
    avg_pool = mean_all * (seq_len / n_real)[:, None]
    # Padding is filled with -inf so it can never be the maximum.
    max_pool, _ = last_layer.masked_fill(pad, -float('inf')).max(dim=1)
    return torch.cat([last_layer[:, -1], max_pool, avg_pool], 1)
|
|
class PoolingLinearClassifier(Module):
    "Create a linear classifier with pooling."
    def __init__(self, layers:Collection[int], drops:Collection[float]):
        """Stack `bn_drop_lin` blocks going through the sizes in `layers`, with dropout `drops[i]` for block `i`.

        Raises `ValueError` unless there is exactly one dropout value per block (`len(layers)-1`)."""
        if len(drops) != len(layers)-1: raise ValueError("Number of layers and dropout values do not match.")
        # One shared ReLU between hidden blocks; the final (output) block gets no activation.
        relu = nn.ReLU(inplace=True)
        activs = [relu] * (len(layers) - 2) + [None]
        blocks = [bn_drop_lin(n_in, n_out, p=p, actn=actn)
                  for n_in, n_out, p, actn in zip(layers[:-1], layers[1:], drops, activs)]
        self.layers = nn.Sequential(*[mod for block in blocks for mod in block])

    def forward(self, input:Tuple[Tensor,Tensor, Tensor])->Tuple[Tensor,Tensor,Tensor]:
        "Pool the encoder `outputs` (ignoring `mask`ed padding), classify, and pass the encoder outputs through."
        raw_outputs, outputs, mask = input
        pooled = masked_concat_pool(outputs, mask)
        return self.layers(pooled), raw_outputs, outputs
|
|
class MultiBatchEncoder(Module):
    "Create an encoder over `module` that can process a full sentence."
    def __init__(self, bptt:int, max_len:int, module:nn.Module, pad_idx:int=1):
        """Feed `module` chunks of `bptt` tokens, keeping the outputs for roughly the last `max_len` tokens.

        `pad_idx` is the token id treated as padding when building the mask."""
        self.max_len,self.bptt,self.module,self.pad_idx = max_len,bptt,module,pad_idx

    def concat(self, arrs:Collection[Collection[Tensor]])->List[Tensor]:
        """Merge per-chunk outputs: for each index `si` of the inner lists, concatenate `arrs[*][si]` along dim 1.

        Dim 1 is the sequence axis (the same axis the input was chunked on), so the result is one tensor per
        index covering all kept chunks."""
        return [torch.cat([l[si] for l in arrs], dim=1) for si in range_of(arrs[0])]

    def reset(self):
        "Reset the wrapped module's state, if it has a `reset` method."
        if hasattr(self.module, 'reset'): self.module.reset()

    def forward(self, input:LongTensor)->Tuple[List[Tensor],List[Tensor],Tensor]:
        """Encode `input` (bs x sl) chunk by chunk.

        Returns the concatenated raw outputs, the concatenated outputs, and a boolean mask that is True where
        the kept input positions equal `pad_idx`."""
        bs,sl = input.size()
        self.reset()
        raw_outputs,outputs,masks = [],[],[]
        for i in range(0, sl, self.bptt):
            chunk = input[:, i:min(i+self.bptt, sl)]
            r, o = self.module(chunk)
            # Only chunks starting inside the last `max_len` positions are kept. Earlier chunks are still run,
            # presumably so the module's state has seen the whole text.
            if i>(sl-self.max_len):
                masks.append(chunk == self.pad_idx)
                raw_outputs.append(r)
                outputs.append(o)
        return self.concat(raw_outputs),self.concat(outputs),torch.cat(masks,dim=1)
|
|
def get_text_classifier(arch:Callable, vocab_sz:int, n_class:int, bptt:int=70, max_len:int=20*70, config:dict=None,
                        drop_mult:float=1., lin_ftrs:Collection[int]=None, ps:Collection[float]=None,
                        pad_idx:int=1) -> nn.Module:
    """Create a text classifier from `arch` and its `config` (no pretrained weights are loaded here).

    The encoder is wrapped in a `MultiBatchEncoder(bptt, max_len, ..., pad_idx)` and followed by a
    `PoolingLinearClassifier`. Its hidden sizes are `lin_ftrs` (default `[50]`) and its dropouts are
    `config['output_p']` followed by `ps` (default 0.1 per hidden layer). Any collection is accepted for
    `lin_ftrs` and `ps`. `config` defaults to the architecture's classifier config; it is copied, and every
    `*_p` dropout entry is multiplied by `drop_mult`."""
    meta = _model_meta[arch]
    config = ifnone(config, meta['config_clas']).copy()
    for k in config.keys():
        if k.endswith('_p'): config[k] *= drop_mult
    # Convert to lists so tuples (or other collections) can be concatenated below.
    lin_ftrs = [50] if lin_ftrs is None else list(lin_ftrs)
    ps = [0.1]*len(lin_ftrs) if ps is None else list(ps)
    # The pooled vector concatenates [last, max-pool, avg-pool] of the hidden size, hence the factor 3.
    layers = [config[meta['hid_name']] * 3] + lin_ftrs + [n_class]
    ps = [config.pop('output_p')] + ps
    init = config.pop('init', None)
    encoder = MultiBatchEncoder(bptt, max_len, arch(vocab_sz, **config), pad_idx=pad_idx)
    model = SequentialRNN(encoder, PoolingLinearClassifier(layers, ps))
    return model if init is None else model.apply(init)
|
|
def text_classifier_learner(data:DataBunch, arch:Callable, bptt:int=70, max_len:int=70*20, config:dict=None,
                            pretrained:bool=True, drop_mult:float=1., lin_ftrs:Collection[int]=None,
                            ps:Collection[float]=None, **learn_kwargs) -> 'RNNLearner':
    """Create a `Learner` with a text classifier from `data` and `arch`.

    When `pretrained`, the architecture's pretrained weights are downloaded and loaded non-strictly, and
    the learner is frozen. Backward weights are used if `data.backwards` is set, matching
    `language_model_learner`."""
    model = get_text_classifier(arch, len(data.vocab.itos), data.c, bptt=bptt, max_len=max_len,
                                config=config, drop_mult=drop_mult, lin_ftrs=lin_ftrs, ps=ps)
    meta = _model_meta[arch]
    learn = RNNLearner(data, model, split_func=meta['split_clas'], **learn_kwargs)
    if pretrained:
        # Pick the weights matching the text direction; `getattr` keeps data without that flag working.
        url = 'url_bwd' if getattr(data, 'backwards', False) else 'url'
        if url not in meta:
            warn("There are no pretrained weights for that architecture yet!")
            return learn
        model_path = untar_data(meta[url], data=False)
        fnames = [list(model_path.glob(f'*.{ext}'))[0] for ext in ['pth', 'pkl']]
        # NOTE(review): the pretrained LM keys look like '0.encoder.weight' / '1.decoder.*', while this model
        # stores its encoder under `MultiBatchEncoder.module`. With strict=False, mismatched keys are silently
        # skipped; confirm the encoder weights actually load (or load them via `load_encoder`).
        learn.load_pretrained(*fnames, strict=False)
        learn.freeze()
    return learn
|
|