# Source code for deeppavlov.models.bert.bert_squad

# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import math
from logging import getLogger
from typing import List, Tuple, Optional, Dict

import numpy as np
import tensorflow as tf
from bert_dp.modeling import BertConfig, BertModel
from bert_dp.optimization import AdamWeightDecayOptimizer
from bert_dp.preprocessing import InputFeatures
from bert_dp.tokenization import FullTokenizer

from deeppavlov import build_model
from deeppavlov.core.commands.utils import expand_path
from deeppavlov.core.common.registry import register
from deeppavlov.core.models.estimator import Component
from deeppavlov.core.models.tf_model import LRScheduledTFModel
from deeppavlov.models.squad.utils import softmax_mask

logger = getLogger(__name__)


@register('squad_bert_model')
class BertSQuADModel(LRScheduledTFModel):
    """Bert-based model for SQuAD-like problem setting:
    It predicts start and end position of answer for given question and context.

    [CLS] token is used as no_answer. If model selects [CLS] token as most probable
    answer, it means that there is no answer in given context.

    Start and end position of answer are predicted by linear transformation
    of Bert outputs.

    Args:
        bert_config_file: path to Bert configuration file
        keep_prob: dropout keep_prob for non-Bert layers; it is fed into ``keep_prob_ph``
            during training (NOTE(review): the answer head built in ``_init_graph`` does not
            read ``keep_prob_ph``, so this value currently has no effect on the graph)
        attention_probs_keep_prob: keep_prob for Bert self-attention layers
        hidden_keep_prob: keep_prob for Bert hidden layers
        optimizer: name of tf.train.* optimizer or None for `AdamWeightDecayOptimizer`
        weight_decay_rate: L2 weight decay for `AdamWeightDecayOptimizer`
        pretrained_bert: pretrained Bert checkpoint; it is used for initialization only when
            no model checkpoint exists at ``load_path`` (or ``load_path`` is None)
        min_learning_rate: min value of learning rate if learning rate decay is used
    """

    def __init__(self, bert_config_file: str, keep_prob: float,
                 attention_probs_keep_prob: Optional[float] = None,
                 hidden_keep_prob: Optional[float] = None,
                 optimizer: Optional[str] = None,
                 weight_decay_rate: Optional[float] = 0.01,
                 pretrained_bert: Optional[str] = None,
                 min_learning_rate: float = 1e-06, **kwargs) -> None:
        super().__init__(**kwargs)

        self.min_learning_rate = min_learning_rate
        self.keep_prob = keep_prob
        self.optimizer = optimizer
        self.weight_decay_rate = weight_decay_rate

        self.bert_config = BertConfig.from_json_file(str(expand_path(bert_config_file)))

        # BertConfig stores dropout (drop) probabilities, the constructor takes keep probabilities
        if attention_probs_keep_prob is not None:
            self.bert_config.attention_probs_dropout_prob = 1.0 - attention_probs_keep_prob
        if hidden_keep_prob is not None:
            self.bert_config.hidden_dropout_prob = 1.0 - hidden_keep_prob

        self.sess_config = tf.ConfigProto(allow_soft_placement=True)
        self.sess_config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=self.sess_config)

        self._init_graph()

        self._init_optimizer()

        self.sess.run(tf.global_variables_initializer())

        if pretrained_bert is not None:
            pretrained_bert = str(expand_path(pretrained_bert))

            # Pretrained Bert weights are only a starting point: a saved model at load_path wins.
            if tf.train.checkpoint_exists(pretrained_bert) and not self._model_checkpoint_exists():
                logger.info('[initializing model with Bert from {}]'.format(pretrained_bert))
                # variables from these scopes (optimizer state, lr/momentum, SQuAD head) are not
                # restored and keep the values set by global_variables_initializer above
                var_list = self._get_saveable_variables(
                    exclude_scopes=('Optimizer', 'learning_rate', 'momentum', 'squad'))
                saver = tf.train.Saver(var_list)
                saver.restore(self.sess, pretrained_bert)

        if self.load_path is not None:
            self.load()

    def _model_checkpoint_exists(self) -> bool:
        """Return True if a model checkpoint exists at ``self.load_path``; False if load_path is None."""
        if self.load_path is None:
            return False
        return tf.train.checkpoint_exists(str(self.load_path.resolve()))

    def _init_graph(self):
        """Build Bert, the linear start/end head, span prediction ops and the training loss."""
        self._init_placeholders()

        seq_len = tf.shape(self.input_ids_ph)[-1]
        # ground-truth start/end positions as one-hot vectors over the subtoken sequence
        self.y_st = tf.one_hot(self.y_st_ph, depth=seq_len)
        self.y_end = tf.one_hot(self.y_end_ph, depth=seq_len)

        self.bert = BertModel(config=self.bert_config,
                              is_training=self.is_train_ph,
                              input_ids=self.input_ids_ph,
                              input_mask=self.input_masks_ph,
                              token_type_ids=self.token_types_ph,
                              use_one_hot_embeddings=False,
                              )

        last_layer = self.bert.get_sequence_output()
        hidden_size = last_layer.get_shape().as_list()[-1]
        bs = tf.shape(last_layer)[0]

        with tf.variable_scope('squad'):
            # a single linear layer producing 2 logits (start, end) per subtoken
            output_weights = tf.get_variable('output_weights', [2, hidden_size],
                                             initializer=tf.truncated_normal_initializer(stddev=0.02))
            output_bias = tf.get_variable('output_bias', [2], initializer=tf.zeros_initializer())

            last_layer_rs = tf.reshape(last_layer, [-1, hidden_size])

            logits = tf.matmul(last_layer_rs, output_weights, transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)
            logits = tf.reshape(logits, [bs, -1, 2])
            # [bs, seq_len, 2] -> [2, bs, seq_len] so start and end logits can be unstacked
            logits = tf.transpose(logits, [2, 0, 1])

            logits_st, logits_end = tf.unstack(logits, axis=0)

            # token type ids are used as the mask of allowed positions
            # (assumes type 1 marks context subtokens — TODO confirm against preprocessor)
            logit_mask = self.token_types_ph
            # [CLS] token is used as no answer, so position 0 is always allowed
            mask = tf.concat([tf.ones((bs, 1), dtype=tf.int32),
                              tf.zeros((bs, seq_len - 1), dtype=tf.int32)], axis=-1)
            logit_mask = logit_mask + mask

            # softmax_mask presumably suppresses logits where the mask is 0
            logits_st = softmax_mask(logits_st, logit_mask)
            logits_end = softmax_mask(logits_end, logit_mask)
            start_probs = tf.nn.softmax(logits_st)
            end_probs = tf.nn.softmax(logits_end)

            # outer[b, i, j] = P(start=i) * P(end=j)
            outer = tf.matmul(tf.expand_dims(start_probs, axis=2),
                              tf.expand_dims(end_probs, axis=1))
            outer_logits = tf.exp(tf.expand_dims(logits_st, axis=2) + tf.expand_dims(logits_end, axis=1))

            context_max_len = tf.reduce_max(tf.reduce_sum(self.token_types_ph, axis=1))

            # keep only spans with start <= end <= start + max_ans_length
            max_ans_length = tf.cast(tf.minimum(20, context_max_len), tf.int64)
            outer = tf.matrix_band_part(outer, 0, max_ans_length)
            outer_logits = tf.matrix_band_part(outer_logits, 0, max_ans_length)

            # 1 - P(start=[CLS]) * P(end=[CLS]): score that an answer exists
            self.yp_score = 1 - start_probs[:, 0] * end_probs[:, 0]

            self.start_probs = start_probs
            self.end_probs = end_probs
            self.start_pred = tf.argmax(tf.reduce_max(outer, axis=2), axis=1)
            self.end_pred = tf.argmax(tf.reduce_max(outer, axis=1), axis=1)
            self.yp_logits = tf.reduce_max(tf.reduce_max(outer_logits, axis=2), axis=1)

        with tf.variable_scope("loss"):
            loss_st = tf.nn.softmax_cross_entropy_with_logits(logits=logits_st, labels=self.y_st)
            loss_end = tf.nn.softmax_cross_entropy_with_logits(logits=logits_end, labels=self.y_end)
            self.loss = tf.reduce_mean(loss_st + loss_end)

    def _init_placeholders(self):
        """Create input, label and training-control placeholders."""
        self.input_ids_ph = tf.placeholder(shape=(None, None), dtype=tf.int32, name='ids_ph')
        self.input_masks_ph = tf.placeholder(shape=(None, None), dtype=tf.int32, name='masks_ph')
        self.token_types_ph = tf.placeholder(shape=(None, None), dtype=tf.int32, name='token_types_ph')

        self.y_st_ph = tf.placeholder(shape=(None,), dtype=tf.int32, name='y_st_ph')
        self.y_end_ph = tf.placeholder(shape=(None,), dtype=tf.int32, name='y_end_ph')

        self.learning_rate_ph = tf.placeholder_with_default(0.0, shape=[], name='learning_rate_ph')
        self.keep_prob_ph = tf.placeholder_with_default(1.0, shape=[], name='keep_prob_ph')
        self.is_train_ph = tf.placeholder_with_default(False, shape=[], name='is_train_ph')

    def _init_optimizer(self):
        """Create the global step and the train op for the configured optimizer."""
        with tf.variable_scope('Optimizer'):
            self.global_step = tf.get_variable('global_step', shape=[], dtype=tf.int32,
                                               initializer=tf.constant_initializer(0), trainable=False)
            if self.optimizer is None:
                # default optimizer for Bert is Adam with fixed L2 regularization
                self.train_op = self.get_train_op(self.loss, learning_rate=self.learning_rate_ph,
                                                  optimizer=AdamWeightDecayOptimizer,
                                                  weight_decay_rate=self.weight_decay_rate,
                                                  beta_1=0.9,
                                                  beta_2=0.999,
                                                  epsilon=1e-6,
                                                  exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]
                                                  )
                # the global step is incremented explicitly together with this train op
                new_global_step = self.global_step + 1
                self.train_op = tf.group(self.train_op, [self.global_step.assign(new_global_step)])
            else:
                self.train_op = self.get_train_op(self.loss, learning_rate=self.learning_rate_ph)

    def _build_feed_dict(self, input_ids, input_masks, token_types, y_st=None, y_end=None):
        """Build a feed dict; training-only entries are added when both y_st and y_end are given."""
        feed_dict = {
            self.input_ids_ph: input_ids,
            self.input_masks_ph: input_masks,
            self.token_types_ph: token_types,
        }
        if y_st is not None and y_end is not None:
            feed_dict.update({
                self.y_st_ph: y_st,
                self.y_end_ph: y_end,
                self.learning_rate_ph: max(self.get_learning_rate(), self.min_learning_rate),
                self.keep_prob_ph: self.keep_prob,
                self.is_train_ph: True,
            })

        return feed_dict

    def train_on_batch(self, features: List[InputFeatures],
                       y_st: List[List[int]], y_end: List[List[int]]) -> Dict:
        """Train model on given batch.
        This method calls train_op using features and labels from y_st and y_end

        Args:
            features: batch of InputFeatures instances
            y_st: batch of lists of ground truth answer start positions
            y_end: batch of lists of ground truth answer end positions

        Returns:
            dict with loss and learning_rate values
        """
        input_ids = [f.input_ids for f in features]
        input_masks = [f.input_mask for f in features]
        input_type_ids = [f.input_type_ids for f in features]

        # only the first ground truth answer of each example is used for training
        y_st = [x[0] for x in y_st]
        y_end = [x[0] for x in y_end]

        feed_dict = self._build_feed_dict(input_ids, input_masks, input_type_ids, y_st, y_end)

        _, loss = self.sess.run([self.train_op, self.loss], feed_dict=feed_dict)
        return {'loss': loss, 'learning_rate': feed_dict[self.learning_rate_ph]}

    def __call__(self, features: List[InputFeatures]) -> Tuple[List[int], List[int], List[float], List[float]]:
        """get predictions using features as input

        Args:
            features: batch of InputFeatures instances

        Returns:
            predictions: start, end positions, logits for answer and no_answer score
        """
        input_ids = [f.input_ids for f in features]
        input_masks = [f.input_mask for f in features]
        input_type_ids = [f.input_type_ids for f in features]

        feed_dict = self._build_feed_dict(input_ids, input_masks, input_type_ids)
        st, end, logits, scores = self.sess.run([self.start_pred, self.end_pred, self.yp_logits, self.yp_score],
                                                feed_dict=feed_dict)
        return st, end, logits.tolist(), scores.tolist()
@register('squad_bert_infer')
class BertSQuADInferModel(Component):
    """This model wraps BertSQuADModel to make predictions on longer than 512 tokens sequences.

    If a context does not fit into ``max_seq_length - 3 - len(question)`` subtokens, it is split
    into ``ceil(len(context) / max_chunk_len)`` groups of consecutive sentences, preserving
    sentence boundaries. Chunks are balanced by number of sentences, so an individual chunk
    may still be longer than ``max_chunk_len`` subtokens.

    It reassembles batches with chunks instead of full contexts to optimize performance, e.g.,:
        batch_size = 5
        number_of_contexts == 2
        number of first context chunks == 8
        number of second context chunks == 2
        we will create two batches with 5 chunks

    For each context the best answer is selected via the logits returned by BertSQuADModel.

    Args:
        squad_model_config: path to DeepPavlov BertSQuADModel config file
        vocab_file: path to Bert vocab file
        do_lower_case: set True if lowercasing is needed
        max_seq_length: max sequence length in subtokens, including [SEP] and [CLS] tokens
        batch_size: size of batch to use during inference
        lang: either `en` or `ru`, it is used to select sentence tokenizer
    """

    def __init__(self, squad_model_config: str,
                 vocab_file: str,
                 do_lower_case: bool,
                 max_seq_length: int = 512,
                 batch_size: int = 10,
                 lang='en', **kwargs) -> None:
        # JSON is UTF-8 by spec; the context manager closes the file after parsing
        with open(squad_model_config, encoding='utf-8') as config_file:
            config = json.load(config_file)
        config['chainer']['pipe'][0]['max_seq_length'] = max_seq_length
        self.model = build_model(config)
        self.max_seq_length = max_seq_length
        vocab_file = str(expand_path(vocab_file))
        self.tokenizer = FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case)
        self.batch_size = batch_size

        if lang == 'en':
            from nltk import sent_tokenize
            self.sent_tokenizer = sent_tokenize
        elif lang == 'ru':
            from ru_sent_tokenize import ru_sent_tokenize
            self.sent_tokenizer = ru_sent_tokenize
        else:
            raise RuntimeError('en and ru languages are supported only')

    def _split_context(self, context: str, question: str) -> List[str]:
        """Split context into non-empty chunks of consecutive sentences (or return it whole if it fits)."""
        context_subtokens = self.tokenizer.tokenize(context)
        question_subtokens = self.tokenizer.tokenize(question)
        # 3 subtokens are reserved for the special [CLS]/[SEP] tokens
        max_chunk_len = self.max_seq_length - len(question_subtokens) - 3
        if not 0 < max_chunk_len < len(context_subtokens):
            return [context]

        number_of_chunks = math.ceil(len(context_subtokens) / max_chunk_len)
        sentences = self.sent_tokenizer(context)
        # np.array_split yields empty parts when there are fewer sentences than chunks;
        # they must be dropped, otherwise the model would be queried with empty contexts
        chunks = [' '.join(chunk) for chunk in np.array_split(sentences, number_of_chunks) if len(chunk) > 0]
        # guarantee at least one chunk so every input context gets a prediction
        return chunks or [context]

    def __call__(self, contexts: List[str], questions: List[str], **kwargs) -> Tuple[List[str], List[int], List[float]]:
        """get predictions for given contexts and questions

        Args:
            contexts: batch of contexts
            questions: batch of questions

        Returns:
            predictions: answer, answer start position, logits or scores
        """
        batch_indices = []
        contexts_to_predict = []
        questions_to_predict = []
        for i, (context, question) in enumerate(zip(contexts, questions)):
            for chunk in self._split_context(context, question):
                contexts_to_predict.append(chunk)
                questions_to_predict.append(question)
                batch_indices.append(i)

        # index of the original context -> list of (answer, answer_start, logits), one per chunk
        predictions = {}
        for j in range(0, len(contexts_to_predict), self.batch_size):
            c_batch = contexts_to_predict[j: j + self.batch_size]
            q_batch = questions_to_predict[j: j + self.batch_size]
            ind_batch = batch_indices[j: j + self.batch_size]
            a_batch, a_st_batch, logits_batch = self.model(c_batch, q_batch)
            for a, a_st, logits, ind in zip(a_batch, a_st_batch, logits_batch, ind_batch):
                predictions.setdefault(ind, []).append((a, a_st, logits))

        answers, answer_starts, logits = [], [], []
        for ind in sorted(predictions.keys()):
            prediction = predictions[ind]
            # the chunk with the highest logits provides the answer for this context
            best_answer_ind = np.argmax([p[2] for p in prediction])
            answers += [prediction[best_answer_ind][0]]
            answer_starts += [prediction[best_answer_ind][1]]
            logits += [prediction[best_answer_ind][2]]

        return answers, answer_starts, logits