Source code for deeppavlov.models.torch_bert.torch_transformers_sequence_tagger

# Copyright 2019 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from logging import getLogger
from pathlib import Path
from typing import List, Union, Dict, Optional, Tuple

import numpy as np
import torch
from transformers import AutoModelForTokenClassification, AutoConfig

from deeppavlov.core.commands.utils import expand_path
from deeppavlov.core.common.errors import ConfigError
from deeppavlov.core.common.registry import register
from deeppavlov.core.models.torch_model import TorchModel
from deeppavlov.models.torch_bert.crf import CRF

log = getLogger(__name__)


def token_from_subtoken(units: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """ Assemble token level units from subtoken level units

    Args:
        units: torch.Tensor of shape [batch_size, SUBTOKEN_seq_length, n_features]
        mask: mask of token beginnings. For example: for tokens

                [[``[CLS]``, ``My``, ``capybara``, ``[SEP]``],
                [``[CLS]``, ``Your``, ``aar``, ``##dvark``, ``is``, ``awesome``, ``[SEP]``]]

            the mask will be

                [[0, 1, 1, 0, 0, 0, 0],
                [0, 1, 1, 0, 1, 1, 0]]

    Returns:
        word_level_units: Units assembled from ones in the mask. For the
            example above these units will correspond to the following

                [[``My``, ``capybara``],
                [``Your``, ``aar``, ``is``, ``awesome``]]

            the shape of this tensor will be [batch_size, TOKEN_seq_length, n_features]
    """
    shape = units.size()
    batch_size = shape[0]
    nf = shape[2]
    nf_int = units.size()[-1]

    token_seq_lengths = torch.sum(mask, 1).to(torch.int64)

    n_words = torch.sum(token_seq_lengths)

    max_token_seq_len = torch.max(token_seq_lengths)

    idxs = torch.stack(torch.nonzero(mask, as_tuple=True), dim=1)

    sample_ids_in_batch = torch.nn.functional.pad(input=idxs[:, 0], pad=[1, 0])

    a = torch.logical_not(torch.eq(sample_ids_in_batch[1:], sample_ids_in_batch[:-1]).to(torch.int64))

    q = a * torch.arange(n_words).to(torch.int64)
    count_to_subtract = torch.nn.functional.pad(torch.masked_select(q, q.to(torch.bool)), [1, 0])

    new_word_indices = torch.arange(n_words).to(torch.int64) - torch.gather(
        count_to_subtract, dim=0, index=torch.cumsum(a, 0))

    n_total_word_elements = (batch_size * max_token_seq_len).to(torch.int32)
    word_indices_flat = (idxs[:, 0] * max_token_seq_len + new_word_indices).to(torch.int64)
    x_mask = torch.sum(torch.nn.functional.one_hot(word_indices_flat, n_total_word_elements), 0)
    x_mask = x_mask.to(torch.bool)

    full_range = torch.arange(batch_size * max_token_seq_len).to(torch.int64)
    nonword_indices_flat = torch.masked_select(full_range, torch.logical_not(x_mask))

    def gather_nd(params, indices):
        # Analogue of tf.gather_nd: pick the rows of `params` addressed by the
        # (sample, subtoken position) index pairs in `indices`.
        assert isinstance(indices, torch.Tensor)
        return params[indices.transpose(0, 1).long().numpy().tolist()]

    elements = gather_nd(units, idxs)

    sh = tuple(torch.stack([torch.sum(max_token_seq_len - token_seq_lengths), torch.tensor(nf)], 0).numpy())
    paddings = torch.zeros(sh, dtype=torch.float64)

    def dynamic_stitch(indices, data):
        # Analogue of tf.dynamic_stitch: scatter the rows of every tensor in `data`
        # into one flat list at the positions given by the matching entry of `indices`.
        # https://discuss.pytorch.org/t/equivalent-of-tf-dynamic-partition/53735/2
        n = sum(idx.numel() for idx in indices)
        res = [None] * n
        for i, data_ in enumerate(data):
            idx = indices[i].view(-1)
            if idx.numel() > 0:
                d = data_.view(idx.numel(), -1)
                k = 0
                for idx_ in idx:
                    res[idx_] = d[k].to(torch.float64)
                    k += 1
        return res

    tensor_flat = torch.stack(dynamic_stitch([word_indices_flat, nonword_indices_flat], [elements, paddings]))

    tensor = torch.reshape(tensor_flat, (batch_size, max_token_seq_len.item(), nf_int))

    return tensor
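
# Usage sketch (not part of the original module): it mirrors the mask example
# from the docstring of ``token_from_subtoken`` and only illustrates shapes.
#
#     units = torch.arange(2 * 7 * 3, dtype=torch.float32).reshape(2, 7, 3)
#     mask = torch.tensor([[0, 1, 1, 0, 0, 0, 0],
#                          [0, 1, 1, 0, 1, 1, 0]])
#     tokens = token_from_subtoken(units, mask)
#     # tokens.shape == (2, 4, 3): the first sample keeps its two word-initial
#     # subtoken vectors and is padded with two zero rows, the second keeps four.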


def token_labels_to_subtoken_labels(labels, y_mask, input_mask):
    """Expand token-level labels to the subtoken level: every subtoken of a word
    gets the label of that word, while special tokens ([CLS], [SEP]) and padding get label 0."""
    subtoken_labels = []
    labels_ind = 0
    n_tokens_with_special = int(np.sum(input_mask))

    for el in y_mask[1:n_tokens_with_special - 1]:
        if el == 1:
            subtoken_labels += [labels[labels_ind]]
            labels_ind += 1
        else:
            subtoken_labels += [labels[labels_ind - 1]]

    subtoken_labels = [0] + subtoken_labels + [0] * (len(input_mask) - n_tokens_with_special + 1)
    return subtoken_labels
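
# Usage sketch (not part of the original module), following the second sentence
# of the docstring example above, ``[CLS] Your aar ##dvark is awesome [SEP]``;
# the tag ids are hypothetical.
#
#     labels = [3, 0, 5, 2]               # one tag per word: Your, aardvark, is, awesome
#     y_mask = [0, 1, 1, 0, 1, 1, 0]      # 1 marks the first subtoken of every word
#     input_mask = [1, 1, 1, 1, 1, 1, 1]  # all seven subtokens are real (no padding)
#     token_labels_to_subtoken_labels(labels, y_mask, input_mask)
#     # -> [0, 3, 0, 0, 5, 2, 0]: ``##dvark`` repeats the tag of ``aar``,
#     #    while [CLS] and [SEP] get the dummy tag 0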


@register('torch_transformers_sequence_tagger')
class TorchTransformersSequenceTagger(TorchModel):
    """Transformer-based model on PyTorch for text tagging. It predicts a label for every token (not subtoken)
    in the text. You can use it for sequence labeling tasks, such as morphological tagging
    or named entity recognition.

    Args:
        n_tags: number of distinct tags
        pretrained_bert: pretrained Bert checkpoint path or key title (e.g. "bert-base-uncased")
        bert_config_file: path to Bert configuration file, or None, if `pretrained_bert` is a string name
        attention_probs_keep_prob: keep_prob for Bert self-attention layers
        hidden_keep_prob: keep_prob for Bert hidden layers
        use_crf: whether to use a Conditional Random Field to decode tags
    """

    def __init__(self,
                 n_tags: int,
                 pretrained_bert: str,
                 bert_config_file: Optional[str] = None,
                 attention_probs_keep_prob: Optional[float] = None,
                 hidden_keep_prob: Optional[float] = None,
                 use_crf: bool = False,
                 **kwargs) -> None:
        if pretrained_bert:
            config = AutoConfig.from_pretrained(pretrained_bert, num_labels=n_tags,
                                                output_attentions=False, output_hidden_states=False)
            model = AutoModelForTokenClassification.from_pretrained(pretrained_bert, config=config)
        elif bert_config_file and Path(bert_config_file).is_file():
            bert_config = AutoConfig.from_json_file(str(expand_path(bert_config_file)))
            if attention_probs_keep_prob is not None:
                bert_config.attention_probs_dropout_prob = 1.0 - attention_probs_keep_prob
            if hidden_keep_prob is not None:
                bert_config.hidden_dropout_prob = 1.0 - hidden_keep_prob
            model = AutoModelForTokenClassification(config=bert_config)
        else:
            raise ConfigError("No pre-trained BERT model is given.")

        self.crf = CRF(n_tags) if use_crf else None
        super().__init__(model, **kwargs)

    def train_on_batch(self,
                       input_ids: Union[List[List[int]], np.ndarray],
                       input_masks: Union[List[List[int]], np.ndarray],
                       y_masks: Union[List[List[int]], np.ndarray],
                       y: List[List[int]],
                       *args, **kwargs) -> Dict[str, float]:
        """Trains the model on a single batch.

        Args:
            input_ids: batch of indices of subwords
            input_masks: batch of masks which determine what should be attended
            y_masks: batch of masks marking the first subword of every word
            y: batch of token-level tag indices
            args: additional positional arguments (unused)
            kwargs: additional keyword arguments (unused)

        Returns:
            dict with the field 'loss' containing the loss value on the batch
        """
        b_input_ids = torch.from_numpy(input_ids).to(self.device)
        b_input_masks = torch.from_numpy(input_masks).to(self.device)
        subtoken_labels = [token_labels_to_subtoken_labels(y_el, y_mask, input_mask)
                           for y_el, y_mask, input_mask in zip(y, y_masks, input_masks)]
        b_labels = torch.from_numpy(np.array(subtoken_labels)).to(torch.int64).to(self.device)
        self.optimizer.zero_grad()

        loss = self.model(input_ids=b_input_ids, attention_mask=b_input_masks, labels=b_labels).loss
        if self.crf is not None:
            self.crf(y, y_masks)
        if self.is_data_parallel:
            loss = loss.mean()
        self._make_step(loss)

        return {'loss': loss.item()}

    def __call__(self,
                 input_ids: Union[List[List[int]], np.ndarray],
                 input_masks: Union[List[List[int]], np.ndarray],
                 y_masks: Union[List[List[int]], np.ndarray]) -> Tuple[List[List[int]], List[np.ndarray]]:
        """Predicts tag indices for a given batch of subword tokens.

        Args:
            input_ids: indices of the subwords
            input_masks: mask that determines where to attend and where not to
            y_masks: mask which determines the first subword units in the word

        Returns:
            Label indices and class probabilities for each token (not subtoken)
        """
        b_input_ids = torch.from_numpy(input_ids).to(self.device)
        b_input_masks = torch.from_numpy(input_masks).to(self.device)

        with torch.no_grad():
            # Forward pass, calculate logit predictions
            logits = self.model(b_input_ids, attention_mask=b_input_masks)
            # Move logits to CPU and assemble token-level units from subtoken-level ones
            logits = token_from_subtoken(logits[0].detach().cpu(), torch.from_numpy(y_masks))

        probas = torch.nn.functional.softmax(logits, dim=-1)
        probas = probas.detach().cpu().numpy()

        if self.crf is not None:
            logits = logits.transpose(1, 0).to(self.device)
            pred = self.crf.decode(logits)
        else:
            logits = logits.detach().cpu().numpy()
            pred = np.argmax(logits, axis=-1)
        seq_lengths = np.sum(y_masks, axis=1)
        pred = [p[:l] for l, p in zip(seq_lengths, pred)]

        return pred, probas
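
    # Usage sketch (not part of the original module). It assumes ``tagger`` is an
    # already-built instance of this class (e.g. created from a DeepPavlov config);
    # the subword ids below are toy values for the two sentences from the
    # ``token_from_subtoken`` docstring.
    #
    #     input_ids = np.array([[101, 1037, 1038, 102, 0, 0, 0],
    #                           [101, 1039, 1040, 1041, 1042, 1043, 102]])
    #     input_masks = np.array([[1, 1, 1, 1, 0, 0, 0],
    #                             [1, 1, 1, 1, 1, 1, 1]])
    #     y_masks = np.array([[0, 1, 1, 0, 0, 0, 0],
    #                         [0, 1, 1, 0, 1, 1, 0]])
    #     pred, probas = tagger(input_ids, input_masks, y_masks)
    #     # pred[0] holds 2 tag ids and pred[1] holds 4, one per word;
    #     # probas has shape [batch_size, max_token_seq_length, n_tags]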

    def load(self, fname=None):
        super().load(fname)
        if self.crf is not None:
            self.crf = self.crf.to(self.device)
            if self.load_path:
                weights_path_crf = Path(f"{self.load_path}_crf").resolve()
                weights_path_crf = weights_path_crf.with_suffix(".pth.tar")
                if weights_path_crf.exists():
                    checkpoint = torch.load(weights_path_crf, map_location=self.device)
                    self.crf.load_state_dict(checkpoint["model_state_dict"], strict=False)
                else:
                    log.warning(f"Init from scratch. Load path {weights_path_crf} does not exist.")

    def save(self, fname: Optional[str] = None, *args, **kwargs) -> None:
        super().save(fname, *args, **kwargs)
        if self.crf is not None:
            if fname is None:
                fname = self.save_path
            weights_path_crf = Path(f"{fname}_crf").resolve()
            weights_path_crf = weights_path_crf.with_suffix(".pth.tar")
            torch.save({"model_state_dict": self.crf.cpu().state_dict()}, weights_path_crf)
            self.crf.to(self.device)