# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import random
from collections import defaultdict
from dataclasses import dataclass
from logging import getLogger
from pathlib import Path
import torch
from typing import Tuple, List, Optional, Union, Dict, Set
import numpy as np
from transformers import AutoTokenizer
from transformers.data.processors.utils import InputFeatures
from deeppavlov.core.commands.utils import expand_path
from deeppavlov.core.common.registry import register
from deeppavlov.core.data.utils import zero_pad
from deeppavlov.core.models.component import Component
from deeppavlov.models.preprocessors.mask import Mask
log = getLogger(__name__)
@register('torch_transformers_multiplechoice_preprocessor')
class TorchTransformersMultiplechoicePreprocessor(Component):
    """Tokenize (context, choice) pairs for multiple-choice models and batch them as tensors.

    Every sample is a list of contexts paired position-wise with a list of answer choices.
    Each pair is encoded by the tokenizer, all encodings are padded together and every
    resulting tensor is reshaped to ``(batch_size, num_choices, -1)``.

    Args:
        vocab_file: path to a vocabulary file, or a model name/path handed to
            ``AutoTokenizer.from_pretrained``
        do_lower_case: set True if lowercasing is needed (forwarded to the tokenizer)
        max_seq_length: max sequence length in subtokens, including special tokens
        return_tokens: stored on the instance; not used by this class

    Attributes:
        max_seq_length: max sequence length in subtokens, including special tokens
        return_tokens: value of the ``return_tokens`` argument
        tokenizer: tokenizer created from ``vocab_file``
    """

    def __init__(self,
                 vocab_file: str,
                 do_lower_case: bool = True,
                 max_seq_length: int = 512,
                 return_tokens: bool = False,
                 **kwargs) -> None:
        self.max_seq_length = max_seq_length
        self.return_tokens = return_tokens
        if Path(vocab_file).is_file():
            vocab_file = str(expand_path(vocab_file))
            # NOTE(review): transformers documents AutoTokenizer as being created through
            # `from_pretrained` only; constructing it directly presumably fails. Left as is
            # because the concrete tokenizer class for a bare vocab file is not known here.
            self.tokenizer = AutoTokenizer(vocab_file=vocab_file,
                                           do_lower_case=do_lower_case)
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(vocab_file, do_lower_case=do_lower_case)

    def tokenize_mc_examples(self,
                             contexts: List[List[str]],
                             choices: List[List[str]]) -> Dict[str, torch.tensor]:
        """Encode every (context, choice) pair and stack the result per sample.

        Args:
            contexts: for each sample, the list of contexts (one per choice)
            choices: for each sample, the list of choices aligned with ``contexts``

        Returns:
            dict of tensors produced by the tokenizer, each viewed as
            ``(batch_size, num_choices, -1)``
        """
        # the number of choices is taken from the first sample; the final `view` relies on
        # every sample having the same number of choices
        num_choices = len(contexts[0])
        batch_size = len(contexts)

        # tokenize examples in groups of `num_choices`
        examples = [
            self.tokenizer.encode_plus(text=context,
                                       text_pair=choice,
                                       return_attention_mask=True,
                                       add_special_tokens=True,
                                       # without an explicit limit `max_seq_length` was never applied
                                       max_length=self.max_seq_length,
                                       truncation=True)
            for context_list, choice_list in zip(contexts, choices)
            for context, choice in zip(context_list, choice_list)
        ]

        padded_examples = self.tokenizer.pad(
            examples,
            padding=True,
            max_length=self.max_seq_length,
            return_tensors='pt',
        )
        return {k: v.view(batch_size, num_choices, -1) for k, v in padded_examples.items()}

    def __call__(self,
                 texts_a: List[List[str]],
                 texts_b: Optional[List[List[str]]] = None) -> Dict[str, torch.tensor]:
        """Tokenize and create masks.

        Args:
            texts_a: for each sample, the list of first texts (contexts)
            texts_b: for each sample, the list of second texts (choices); when None every
                context is encoded on its own, without a pair

        Returns:
            dict of tensors as returned by :meth:`tokenize_mc_examples`
        """
        if texts_b is None:
            # the default used to crash inside `zip`; encode each context without a pair instead
            texts_b = [[None] * len(context_list) for context_list in texts_a]
        return self.tokenize_mc_examples(texts_a, texts_b)
@register('torch_squad_transformers_preprocessor')
class TorchSquadTransformersPreprocessor(Component):
    """Encode text pairs one by one into fixed-length :class:`InputFeatures`.

    Each pair is passed to the tokenizer with special tokens added, truncated and padded to
    ``max_seq_length``. When the tokenizer output has no ``token_type_ids``, a replacement
    is created (see ``add_token_type_ids``).

    Args:
        vocab_file: path to a vocabulary file, or a model name/path handed to
            ``AutoTokenizer.from_pretrained``
        do_lower_case: set True if lowercasing is needed (forwarded to the tokenizer)
        max_seq_length: max sequence length in subtokens, including special tokens
        return_tokens: whether to return tuple of input features and tokens, or only input features
        add_token_type_ids: whether to build segment ids from the position of the first
            separator token when the tokenizer does not return them

    Attributes:
        max_seq_length: max sequence length in subtokens, including special tokens
        return_tokens: whether to return tuple of input features and tokens, or only input features
        add_token_type_ids: value of the ``add_token_type_ids`` argument
        tokenizer: tokenizer created from ``vocab_file``
    """

    def __init__(self,
                 vocab_file: str,
                 do_lower_case: bool = True,
                 max_seq_length: int = 512,
                 return_tokens: bool = False,
                 add_token_type_ids: bool = False,
                 **kwargs) -> None:
        self.max_seq_length = max_seq_length
        self.return_tokens = return_tokens
        self.add_token_type_ids = add_token_type_ids

        if not Path(vocab_file).is_file():
            self.tokenizer = AutoTokenizer.from_pretrained(vocab_file, do_lower_case=do_lower_case)
        else:
            # NOTE(review): transformers documents AutoTokenizer as being created through
            # `from_pretrained` only -- confirm that this branch can actually succeed.
            vocab_file = str(expand_path(vocab_file))
            self.tokenizer = AutoTokenizer(vocab_file=vocab_file,
                                           do_lower_case=do_lower_case)

    def _fallback_token_type_ids(self, input_ids):
        """Create ``token_type_ids`` for an encoding that came without them.

        With ``add_token_type_ids`` enabled, positions up to and including the first separator
        token get 0 and the remaining positions get 1; otherwise a one-element zero tensor
        is returned.
        """
        if not self.add_token_type_ids:
            return torch.tensor([0])

        total_len = input_ids.size(1)
        # column of the first separator token in the (single) encoded row
        first_sep = torch.where(input_ids == self.tokenizer.sep_token_id)[1][0].item()
        first_len = min(first_sep + 1, total_len)
        second_len = total_len - first_len
        segments = (torch.zeros(1, first_len, dtype=int),
                    torch.ones(1, second_len, dtype=int))
        return torch.cat(segments, dim=1)

    def __call__(self, texts_a: List[str], texts_b: Optional[List[str]] = None) -> Union[List[InputFeatures],
                                                                                          Tuple[List[InputFeatures],
                                                                                                List[List[str]]]]:
        """Tokenize and create masks.

        Args:
            texts_a: list of texts
            texts_b: list of second texts paired with ``texts_a``; may be None, in which case
                every text is encoded without a pair

        Returns:
            list of :class:`transformers.data.processors.utils.InputFeatures`, or a tuple of
            that list and the list of subtokens per text when ``return_tokens`` is set
        """
        second_texts = [None] * len(texts_a) if texts_b is None else texts_b

        features = []
        subtokens = []
        for first, second in zip(texts_a, second_texts):
            encoding = self.tokenizer.encode_plus(
                text=first, text_pair=second,
                add_special_tokens=True,
                max_length=self.max_seq_length,
                truncation=True,
                padding='max_length',
                return_attention_mask=True,
                return_tensors='pt')

            if 'token_type_ids' not in encoding:
                encoding['token_type_ids'] = self._fallback_token_type_ids(encoding['input_ids'])

            features.append(InputFeatures(input_ids=encoding['input_ids'],
                                          attention_mask=encoding['attention_mask'],
                                          token_type_ids=encoding['token_type_ids'],
                                          label=None))
            if self.return_tokens:
                subtokens.append(self.tokenizer.convert_ids_to_tokens(encoding['input_ids'][0]))

        return (features, subtokens) if self.return_tokens else features
@register('torch_bert_ranker_preprocessor')
class TorchBertRankerPreprocessor(TorchTransformersPreprocessor):
    """Tokenize text to sub-tokens, encode sub-tokens with their indices, create tokens and segment masks for ranking.

    Builds features for a pair of context with each of the response candidates.
    Uses ``self.tokenizer`` and ``self.max_seq_length``, which are presumably set up by the
    parent class (not defined in this part of the file).
    """

    def __call__(self, batch: List[List[str]]) -> List[List[InputFeatures]]:
        """Tokenize and create masks.

        Args:
            batch: list of samples; the first item of every sample is the context and the
                remaining items are its response candidates. A flat list of strings is
                treated as one sample.

        Returns:
            list of feature batches: one inner list per response position (a single inner
            list when samples hold only a context), each with one
            :class:`InputFeatures` per sample.
        """
        if isinstance(batch[0], str):
            batch = [batch]

        # The context is always the first item of a sample. The context-only branch used to
        # take `batch[0]` here, which dropped every sample after the first one.
        contexts = [sample[0] for sample in batch]
        items_per_sample = len(batch[0])
        if items_per_sample == 1:
            # no response candidates: encode every context without a pair
            response_columns = [[None] * len(batch)]
        else:
            response_columns = [[sample[i] for sample in batch]
                                for i in range(1, items_per_sample)]

        input_features = []
        for responses in response_columns:
            sub_list_features = []
            for context, response in zip(contexts, responses):
                # `padding='max_length'` + `truncation=True` replace the deprecated
                # `pad_to_max_length=True`
                encoded_dict = self.tokenizer.encode_plus(
                    text=context, text_pair=response, add_special_tokens=True,
                    max_length=self.max_seq_length, truncation=True, padding='max_length',
                    return_attention_mask=True, return_tensors='pt')
                curr_features = InputFeatures(input_ids=encoded_dict['input_ids'],
                                              attention_mask=encoded_dict['attention_mask'],
                                              token_type_ids=encoded_dict['token_type_ids'],
                                              label=None)
                sub_list_features.append(curr_features)
            input_features.append(sub_list_features)
        return input_features
@dataclass
class RecordFlatExample:
    """One flattened ReCoRD example: a single candidate `entity` with the `probability`
    predicted for it and its `label`.
    """
    # index of the example this candidate belongs to
    index: str
    # label of the candidate
    label: int
    # predicted probability for the candidate
    probability: float
    # candidate entity text
    entity: str
@dataclass
class RecordNestedExample:
    """One nested ReCoRD example: the single predicted entity together with the list of
    correct answers.
    """
    # index of the example
    index: str
    # predicted entity
    prediction: str
    # correct answers for the example
    answers: List[str]
@register("torch_record_postprocessor")
class TorchRecordPostprocessor:
    """Combines flat classification examples into nested examples. When called returns nested examples
    that weren't previously returned during current iteration over examples.

    Args:
        is_binary: signifies whether the classifier uses binary classification head

    Attributes:
        record_example_accumulator: underlying accumulator that transforms flat examples
        total_examples: overall number of flat examples that must be processed during current iteration
        is_binary: value of the ``is_binary`` argument
    """

    def __init__(self, is_binary: bool = False, *args, **kwargs):
        self.record_example_accumulator: RecordExampleAccumulator = RecordExampleAccumulator()
        self.total_examples: Optional[int] = None
        self.is_binary: bool = is_binary

    def __call__(self,
                 idx: List[str],
                 y: List[int],
                 y_pred_probas: np.ndarray,
                 entities: List[str],
                 num_examples: List[int],
                 *args,
                 **kwargs) -> List[RecordNestedExample]:
        """Postprocessor call

        Args:
            idx: list of string indices
            y: list of integer labels
            y_pred_probas: array of predicted probabilities
            entities: list of candidate entities, aligned with ``idx``
            num_examples: list of duplicated total numbers of examples

        Returns:
            List[RecordNestedExample]: processed but not previously returned examples (may be empty in some cases)
        """
        if not self.is_binary:
            # if we have outputs for both classes `0` and `1`
            y_pred_probas = y_pred_probas[:, 1]
        if self.total_examples != num_examples[0]:
            # start over if num_examples is different
            # implying that a different split is being evaluated
            self.reset_accumulator()
            self.total_examples = num_examples[0]

        new_examples: List[RecordNestedExample] = []
        for index, label, probability, entity in zip(idx, y, y_pred_probas, entities):
            accumulator = self.record_example_accumulator
            accumulator.add_flat_example(index, label, probability, entity)
            accumulator.collect_nested_example(index)
            if accumulator.examples_processed >= self.total_examples:
                # All examples were processed: take what this accumulator still holds BEFORE
                # starting over, otherwise the examples completed in this call would be lost.
                new_examples.extend(accumulator.return_examples())
                self.reset_accumulator()
        new_examples.extend(self.record_example_accumulator.return_examples())
        return new_examples

    def reset_accumulator(self):
        """Reinitialize the underlying accumulator from scratch
        """
        self.record_example_accumulator = RecordExampleAccumulator()
class RecordExampleAccumulator:
    """Accumulates flat ReCoRD examples and turns them into nested ones.

    Attributes:
        examples_processed: total number of flat examples added so far
        record_counter: number of flat examples added for each index
        nested_len: expected number of flat examples for each index
        flat_examples: flat examples grouped by index
        nested_examples: nested examples built so far, by index
        collected_indices: indices for which a nested example was built
        returned_indices: indices already handed out by `return_examples`
    """

    def __init__(self):
        self.examples_processed: int = 0
        self.record_counter: Dict[str, int] = defaultdict(int)
        self.nested_len: Dict[str, int] = {}
        self.flat_examples: Dict[str, List[RecordFlatExample]] = defaultdict(list)
        self.nested_examples: Dict[str, RecordNestedExample] = {}
        self.collected_indices: Set[str] = set()
        self.returned_indices: Set[str] = set()

    def add_flat_example(self, index: str, label: int, probability: float, entity: str):
        """Store one flat example and update the counters.

        Args:
            index: example index
            label: example label (`-1` means that label is not available)
            probability: predicted probability
            entity: candidate entity
        """
        flat = RecordFlatExample(index, label, probability, entity)
        self.flat_examples[index].append(flat)
        if index not in self.nested_len:
            # the expected group size is derived from the index itself
            self.nested_len[index] = self.get_expected_len(index)
        self.record_counter[index] += 1
        self.examples_processed += 1

    def ready_to_nest(self, index: str) -> bool:
        """Tell whether all flat examples for the given index have been added.

        Args:
            index: the index of the candidate nested example

        Returns:
            bool: True when the number of added flat examples equals the expected number
        """
        seen = self.record_counter[index]
        expected = self.nested_len[index]
        return seen == expected

    def collect_nested_example(self, index: str):
        """Build the nested example for the given index once all of its flat examples are in.

        The prediction is the entity with the highest probability; the answers are the
        entities whose label equals 1. Does nothing while the group is incomplete.

        Args:
            index: the index of the candidate nested example
        """
        if not self.ready_to_nest(index):
            return

        group: List[RecordFlatExample] = self.flat_examples[index]
        candidates: List[str] = [flat.entity for flat in group]
        scores: List[float] = [flat.probability for flat in group]
        gold: List[str] = [flat.entity for flat in group if flat.label == 1]

        best = np.argmax(scores)
        self.nested_examples[index] = RecordNestedExample(index, candidates[best], gold)
        self.collected_indices.add(index)

    def return_examples(self) -> List[RecordNestedExample]:
        """Return the nested examples that were built but not returned yet.

        The returned indices are remembered, so an immediate second call yields an empty list.

        Returns:
            List[RecordNestedExample]: zero or more nested examples
        """
        pending: Set[str] = self.collected_indices - self.returned_indices
        fresh: List[RecordNestedExample] = [self.nested_examples[index] for index in pending]
        self.returned_indices |= pending
        return fresh

    @staticmethod
    def get_expected_len(index: str) -> int:
        """Read the expected number of flat examples from the index.

        Args:
            index: the index to calculate the number of examples for

        Returns:
            int: the integer value of the part of the index after its last "-"
        """
        _, _, tail = index.rpartition("-")
        return int(tail)