# Source code for deeppavlov.models.bert.bert_classifier

# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from logging import getLogger
from typing import List, Dict, Union

import tensorflow as tf
from bert_dp.modeling import BertConfig, BertModel
from bert_dp.optimization import AdamWeightDecayOptimizer
from bert_dp.preprocessing import InputFeatures

from deeppavlov.core.commands.utils import expand_path
from deeppavlov.core.common.registry import register
from deeppavlov.core.models.tf_model import LRScheduledTFModel

logger = getLogger(__name__)


@register('bert_classifier')
class BertClassifierModel(LRScheduledTFModel):
    """Bert-based model for text classification.

    It uses output from [CLS] token and predicts labels using linear transformation.

    Args:
        bert_config_file: path to Bert configuration file
        n_classes: number of classes
        keep_prob: dropout keep_prob for non-Bert layers
        one_hot_labels: set True if one-hot encoding for labels is used
        multilabel: set True if it is multi-label classification
        return_probas: set True to return class probabilities instead of the most probable label
        attention_probs_keep_prob: keep_prob for Bert self-attention layers
        hidden_keep_prob: keep_prob for Bert hidden layers
        optimizer: name of tf.train.* optimizer or None for `AdamWeightDecayOptimizer`
        num_warmup_steps: intended for learning rate warmup; currently only stored,
            warmup is not implemented yet (see TODO below)
        weight_decay_rate: L2 weight decay for `AdamWeightDecayOptimizer`
        pretrained_bert: pretrained Bert checkpoint
        min_learning_rate: lower bound applied to the learning rate fed during training
            (relevant if learning rate decay is used)

    Raises:
        RuntimeError: if ``multilabel`` is True but ``one_hot_labels`` or ``return_probas`` is False.
    """

    # TODO: add warmup
    # TODO: add head-only pre-training
    def __init__(self, bert_config_file, n_classes, keep_prob,
                 one_hot_labels=False, multilabel=False, return_probas=False,
                 attention_probs_keep_prob=None, hidden_keep_prob=None,
                 optimizer=None, num_warmup_steps=None, weight_decay_rate=0.01,
                 pretrained_bert=None, min_learning_rate=1e-06, **kwargs) -> None:
        super().__init__(**kwargs)

        self.return_probas = return_probas
        self.n_classes = n_classes
        self.min_learning_rate = min_learning_rate
        self.keep_prob = keep_prob
        self.one_hot_labels = one_hot_labels
        self.multilabel = multilabel
        self.optimizer = optimizer
        self.num_warmup_steps = num_warmup_steps
        self.weight_decay_rate = weight_decay_rate

        if self.multilabel and not self.one_hot_labels:
            raise RuntimeError('Use one-hot encoded labels for multilabel classification!')

        if self.multilabel and not self.return_probas:
            raise RuntimeError('Set return_probas to True for multilabel classification!')

        self.bert_config = BertConfig.from_json_file(str(expand_path(bert_config_file)))

        # Override dropout rates from the Bert config only when explicitly requested.
        if attention_probs_keep_prob is not None:
            self.bert_config.attention_probs_dropout_prob = 1.0 - attention_probs_keep_prob
        if hidden_keep_prob is not None:
            self.bert_config.hidden_dropout_prob = 1.0 - hidden_keep_prob

        self.sess_config = tf.ConfigProto(allow_soft_placement=True)
        self.sess_config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=self.sess_config)

        self._init_graph()

        self._init_optimizer()

        self.sess.run(tf.global_variables_initializer())

        if pretrained_bert is not None:
            pretrained_bert = str(expand_path(pretrained_bert))

            # Warm-start from the pretrained Bert checkpoint only when there is no
            # checkpoint of this model at load_path. load_path may be None (it is
            # checked again before self.load() below), so guard it before .resolve().
            if tf.train.checkpoint_exists(pretrained_bert) \
                    and (self.load_path is None
                         or not tf.train.checkpoint_exists(str(self.load_path.resolve()))):
                logger.info('[initializing model with Bert from {}]'.format(pretrained_bert))
                # Exclude optimizer and classification variables from saved variables
                var_list = self._get_saveable_variables(
                    exclude_scopes=('Optimizer', 'learning_rate', 'momentum', 'output_weights', 'output_bias'))
                saver = tf.train.Saver(var_list)
                saver.restore(self.sess, pretrained_bert)

        if self.load_path is not None:
            self.load()

    def _init_graph(self):
        """Build placeholders, the Bert encoder, the linear head, predictions and loss."""
        self._init_placeholders()

        self.bert = BertModel(config=self.bert_config,
                              is_training=self.is_train_ph,
                              input_ids=self.input_ids_ph,
                              input_mask=self.input_masks_ph,
                              token_type_ids=self.token_types_ph,
                              use_one_hot_embeddings=False,
                              )

        output_layer = self.bert.get_pooled_output()
        hidden_size = output_layer.shape[-1].value

        # Linear classification head on top of the pooled Bert output.
        output_weights = tf.get_variable(
            "output_weights", [self.n_classes, hidden_size],
            initializer=tf.truncated_normal_initializer(stddev=0.02))

        output_bias = tf.get_variable(
            "output_bias", [self.n_classes], initializer=tf.zeros_initializer())

        with tf.variable_scope("loss"):
            # keep_prob_ph defaults to 1.0, i.e. no dropout outside of training.
            output_layer = tf.nn.dropout(output_layer, keep_prob=self.keep_prob_ph)
            logits = tf.matmul(output_layer, output_weights, transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)

            if self.one_hot_labels:
                one_hot_labels = self.y_ph
            else:
                one_hot_labels = tf.one_hot(self.y_ph, depth=self.n_classes, dtype=tf.float32)

            self.y_predictions = tf.argmax(logits, axis=-1)
            if not self.multilabel:
                # Single-label: softmax probabilities and cross-entropy over classes.
                log_probs = tf.nn.log_softmax(logits, axis=-1)
                self.y_probas = tf.nn.softmax(logits, axis=-1)
                per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
                self.loss = tf.reduce_mean(per_example_loss)
            else:
                # Multi-label: independent sigmoid per class.
                self.y_probas = tf.nn.sigmoid(logits)
                self.loss = tf.reduce_mean(
                    tf.nn.sigmoid_cross_entropy_with_logits(labels=one_hot_labels, logits=logits))

    def _init_placeholders(self):
        """Create input, label and training-control placeholders."""
        self.input_ids_ph = tf.placeholder(shape=(None, None), dtype=tf.int32, name='ids_ph')
        self.input_masks_ph = tf.placeholder(shape=(None, None), dtype=tf.int32, name='masks_ph')
        self.token_types_ph = tf.placeholder(shape=(None, None), dtype=tf.int32, name='token_types_ph')

        # Labels are class ids (batch,) or one-hot/multi-hot vectors (batch, n_classes).
        if not self.one_hot_labels:
            self.y_ph = tf.placeholder(shape=(None, ), dtype=tf.int32, name='y_ph')
        else:
            self.y_ph = tf.placeholder(shape=(None, self.n_classes), dtype=tf.float32, name='y_ph')

        # Defaults correspond to inference mode; training feeds real values.
        self.learning_rate_ph = tf.placeholder_with_default(0.0, shape=[], name='learning_rate_ph')
        self.keep_prob_ph = tf.placeholder_with_default(1.0, shape=[], name='keep_prob_ph')
        self.is_train_ph = tf.placeholder_with_default(False, shape=[], name='is_train_ph')

    def _init_optimizer(self):
        """Create ``global_step`` and ``train_op`` under the 'Optimizer' variable scope.

        With ``optimizer=None`` Bert's ``AdamWeightDecayOptimizer`` is used and an
        explicit ``global_step`` increment is grouped into ``train_op``.
        """
        with tf.variable_scope('Optimizer'):
            self.global_step = tf.get_variable('global_step', shape=[], dtype=tf.int32,
                                               initializer=tf.constant_initializer(0), trainable=False)
            # default optimizer for Bert is Adam with fixed L2 regularization
            if self.optimizer is None:
                self.train_op = self.get_train_op(self.loss, learning_rate=self.learning_rate_ph,
                                                  optimizer=AdamWeightDecayOptimizer,
                                                  weight_decay_rate=self.weight_decay_rate,
                                                  beta_1=0.9,
                                                  beta_2=0.999,
                                                  epsilon=1e-6,
                                                  exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]
                                                  )
                # global_step is not passed to get_train_op above, so advance it
                # explicitly on every training step.
                new_global_step = self.global_step + 1
                self.train_op = tf.group(self.train_op, [self.global_step.assign(new_global_step)])
            else:
                # NOTE(review): self.optimizer is only compared with None in this class;
                # the optimizer name is not forwarded to get_train_op or super().__init__
                # here — confirm the intended tf.train optimizer is actually used.
                self.train_op = self.get_train_op(self.loss, learning_rate=self.learning_rate_ph)

    def _build_feed_dict(self, input_ids, input_masks, token_types, y=None):
        """Build a feed dict; passing ``y`` also switches the graph into training mode."""
        feed_dict = {
            self.input_ids_ph: input_ids,
            self.input_masks_ph: input_masks,
            self.token_types_ph: token_types,
        }
        if y is not None:
            feed_dict.update({
                self.y_ph: y,
                # Never feed a learning rate below min_learning_rate.
                self.learning_rate_ph: max(self.get_learning_rate(), self.min_learning_rate),
                self.keep_prob_ph: self.keep_prob,
                self.is_train_ph: True,
            })

        return feed_dict

    def train_on_batch(self, features: List[InputFeatures], y: Union[List[int], List[List[int]]]) -> Dict:
        """Train model on given batch.

        This method calls train_op using features and y (labels).

        Args:
            features: batch of InputFeatures
            y: batch of labels (class id or one-hot encoding)

        Returns:
            dict with loss and learning_rate values
        """
        input_ids = [f.input_ids for f in features]
        input_masks = [f.input_mask for f in features]
        input_type_ids = [f.input_type_ids for f in features]

        feed_dict = self._build_feed_dict(input_ids, input_masks, input_type_ids, y)

        _, loss = self.sess.run([self.train_op, self.loss], feed_dict=feed_dict)
        return {'loss': loss, 'learning_rate': feed_dict[self.learning_rate_ph]}

    def __call__(self, features: List[InputFeatures]) -> Union[List[int], List[List[float]]]:
        """Make prediction for given features (texts).

        Args:
            features: batch of InputFeatures

        Returns:
            predicted class ids (argmax over logits) if ``return_probas`` is False,
            otherwise class probabilities (softmax, or sigmoid for multilabel),
            as returned by ``sess.run``
        """
        input_ids = [f.input_ids for f in features]
        input_masks = [f.input_mask for f in features]
        input_type_ids = [f.input_type_ids for f in features]

        feed_dict = self._build_feed_dict(input_ids, input_masks, input_type_ids)
        if not self.return_probas:
            pred = self.sess.run(self.y_predictions, feed_dict=feed_dict)
        else:
            pred = self.sess.run(self.y_probas, feed_dict=feed_dict)
        return pred