# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from logging import getLogger
from typing import List

import numpy as np
from keras import backend as K
from keras import losses
from keras.initializers import glorot_uniform, Orthogonal
from keras.layers import Input, LSTM, Embedding, GlobalMaxPooling1D, Lambda, Dense, Layer
from keras.layers.merge import Multiply
from keras.layers.wrappers import Bidirectional
from keras.models import Model
from keras.optimizers import Adam
from tensorflow.python.framework.ops import Tensor

from deeppavlov.core.common.registry import register
from deeppavlov.models.ranking.keras_siamese_model import KerasSiameseModel

log = getLogger(__name__)


@register('bilstm_nn')
class BiLSTMSiameseNetwork(KerasSiameseModel):
"""The class implementing a siamese neural network with BiLSTM and max pooling.
There is a possibility to use a binary cross-entropy loss as well as
a triplet loss with random or hard negative sampling.
Args:
len_vocab: A size of the vocabulary to build embedding layer.
seed: Random seed.
shared_weights: Whether to use shared weights in the model to encode ``contexts`` and ``responses``.
embedding_dim: Dimensionality of token (word) embeddings.
reccurent: A type of the RNN cell. Possible values are ``lstm`` and ``bilstm``.
hidden_dim: Dimensionality of the hidden state of the RNN cell. If ``reccurent`` equals ``bilstm``
``hidden_dim`` should be doubled to get the actual dimensionality.
max_pooling: Whether to use max-pooling operation to get ``context`` (``response``) vector representation.
If ``False``, the last hidden state of the RNN will be used.
triplet_loss: Whether to use a model with triplet loss.
If ``False``, a model with crossentropy loss will be used.
margin: A margin parameter for triplet loss. Only required if ``triplet_loss`` is set to ``True``.
hard_triplets: Whether to use hard triplets sampling to train the model
i.e. to choose negative samples close to positive ones.
If set to ``False`` random sampling will be used.
Only required if ``triplet_loss`` is set to ``True``.
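
    Example:
        A minimal construction sketch; the keyword values below are purely
        illustrative, and the remaining parameters (``max_sequence_length``,
        ``learning_rate``, ``use_matrix`` and so on) come from the
        ``KerasSiameseModel`` base class via the DeepPavlov configuration::

            network = BiLSTMSiameseNetwork(len_vocab=10000,
                                           embedding_dim=300,
                                           hidden_dim=300,
                                           triplet_loss=False)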
"""

    def __init__(self,
len_vocab: int,
seed: int = None,
shared_weights: bool = True,
embedding_dim: int = 300,
                 recurrent: str = "bilstm",
hidden_dim: int = 300,
max_pooling: bool = True,
triplet_loss: bool = True,
margin: float = 0.1,
hard_triplets: bool = False,
*args,
**kwargs) -> None:
self.toks_num = len_vocab
self.seed = seed
self.hidden_dim = hidden_dim
self.shared_weights = shared_weights
self.pooling = max_pooling
        self.recurrent = recurrent
self.margin = margin
self.embedding_dim = embedding_dim
self.hard_triplets = hard_triplets
self.triplet_mode = triplet_loss
super(BiLSTMSiameseNetwork, self).__init__(*args, **kwargs)

    def compile(self) -> None:
optimizer = Adam(lr=self.learning_rate)
if self.triplet_mode:
loss = self._triplet_loss
else:
loss = losses.binary_crossentropy
self.model.compile(loss=loss, optimizer=optimizer)
self.score_model = self.create_score_model()

    def load_initial_emb_matrix(self) -> None:
log.info("[initializing new `{}`]".format(self.__class__.__name__))
if self.use_matrix:
if self.shared_weights:
self.model.get_layer(name="embedding").set_weights([self.emb_matrix])
else:
self.model.get_layer(name="embedding_a").set_weights([self.emb_matrix])
self.model.get_layer(name="embedding_b").set_weights([self.emb_matrix])

    def embedding_layer(self, name: str = "embedding") -> Layer:
        # The layer name is parameterized so that the two separate embedding
        # layers built when `shared_weights` is False ("embedding_a" and
        # "embedding_b") get distinct names and can be addressed in
        # load_initial_emb_matrix.
        out = Embedding(self.toks_num,
                        self.embedding_dim,
                        input_length=self.max_sequence_length,
                        trainable=True, name=name)
        return out

    def lstm_layer(self) -> Layer:
        # Return full sequences when max pooling is applied afterwards;
        # otherwise only the last hidden state is needed.
        ret_seq = self.pooling
ker_in = glorot_uniform(seed=self.seed)
rec_in = Orthogonal(seed=self.seed)
if self.recurrent == "bilstm" or self.recurrent is None:
out = Bidirectional(LSTM(self.hidden_dim,
input_shape=(self.max_sequence_length, self.embedding_dim,),
kernel_initializer=ker_in,
recurrent_initializer=rec_in,
return_sequences=ret_seq), merge_mode='concat')
        elif self.recurrent == "lstm":
            out = LSTM(self.hidden_dim,
                       input_shape=(self.max_sequence_length, self.embedding_dim,),
                       kernel_initializer=ker_in,
                       recurrent_initializer=rec_in,
                       return_sequences=ret_seq)
        else:
            raise ValueError("Unexpected `recurrent` value: {}".format(self.recurrent))
        return out
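
    # Builds the training graph: context and response inputs go through an
    # embedding layer and a (bi)LSTM encoder, each optionally shared, followed
    # by max pooling over time. The two sentence vectors are then turned
    # either into a pairwise-distance matrix (triplet mode) or into a dense
    # sigmoid score (cross-entropy mode).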
def create_model(self) -> Model:
if self.use_matrix:
context = Input(shape=(self.max_sequence_length,))
response = Input(shape=(self.max_sequence_length,))
            if self.shared_weights:
                emb_layer_a = self.embedding_layer()
                emb_layer_b = emb_layer_a
            else:
                emb_layer_a = self.embedding_layer(name="embedding_a")
                emb_layer_b = self.embedding_layer(name="embedding_b")
emb_c = emb_layer_a(context)
emb_r = emb_layer_b(response)
else:
context = Input(shape=(self.max_sequence_length, self.embedding_dim,))
response = Input(shape=(self.max_sequence_length, self.embedding_dim,))
emb_c = context
emb_r = response
if self.shared_weights:
lstm_layer_a = self.lstm_layer()
lstm_layer_b = lstm_layer_a
else:
lstm_layer_a = self.lstm_layer()
lstm_layer_b = self.lstm_layer()
lstm_c = lstm_layer_a(emb_c)
lstm_r = lstm_layer_b(emb_r)
        if self.pooling:
            # The same pooling layer instance is applied to both branches so
            # that the shared "sentence_embedding" node can be queried per
            # input in create_score_model.
            pooling_layer = GlobalMaxPooling1D(name="sentence_embedding")
            lstm_c = pooling_layer(lstm_c)
            lstm_r = pooling_layer(lstm_r)
if self.triplet_mode:
dist = Lambda(self._pairwise_distances)([lstm_c, lstm_r])
else:
dist = Lambda(self._diff_mult_dist)([lstm_c, lstm_r])
dist = Dense(1, activation='sigmoid', name="score_model")(dist)
model = Model([context, response], dist)
return model
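
    # Builds the inference model mapping a (context, response) pair to a
    # similarity score in [0, 1]: one minus the euclidian distance between the
    # sentence embeddings in triplet mode, or the sigmoid output of the
    # "score_model" layer in cross-entropy mode. Note that the triplet branch
    # queries the "sentence_embedding" layer, so it assumes ``max_pooling`` is
    # enabled.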
def create_score_model(self) -> Model:
cr = self.model.inputs
if self.triplet_mode:
emb_c = self.model.get_layer("sentence_embedding").get_output_at(0)
emb_r = self.model.get_layer("sentence_embedding").get_output_at(1)
dist_score = Lambda(lambda x: self._euclidian_dist(x), name="score_model")
score = dist_score([emb_c, emb_r])
        else:
            # The Dense "score_model" layer outputs the probability of a
            # positive pair; convert it to a distance here so that the final
            # inversion below yields a similarity in both branches.
            score = self.model.get_layer("score_model").output
            score = Lambda(lambda x: 1. - K.squeeze(x, -1))(score)
        score = Lambda(lambda x: 1. - x)(score)
model = Model(cr, score)
return model
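
    # Matching features commonly used in sentence-pair models: concatenation
    # of both vectors, their absolute difference and their element-wise
    # product, fed to the final dense scoring layer.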
def _diff_mult_dist(self, inputs: List[Tensor]) -> Tensor:
input1, input2 = inputs
a = K.abs(input1-input2)
b = Multiply()(inputs)
return K.concatenate([input1, input2, a, b])

    def _euclidian_dist(self, x_pair: List[Tensor]) -> Tensor:
        x1_norm = K.l2_normalize(x_pair[0], axis=1)
        x2_norm = K.l2_normalize(x_pair[1], axis=1)
        diff = x1_norm - x2_norm
        square = K.square(diff)
        dist_sq = K.sum(square, axis=1)
        # Clip away exact zeros: the gradient of sqrt at zero is infinite.
        dist_sq = K.clip(dist_sq, min_value=1e-12, max_value=None)
        # Unit vectors are at most 2.0 apart, so dividing by 2 keeps the
        # distance in [0, 1].
        dist = K.sqrt(dist_sq) / 2.
        return dist
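
    # Squared distances between every context and every response in the batch,
    # computed from the expansion ||c - r||^2 = ||c||^2 - 2 * c.r + ||r||^2 on
    # the concatenated embedding matrix. Exactly-zero entries are shifted by a
    # small epsilon before K.sqrt, whose gradient is infinite at zero, and are
    # zeroed again afterwards.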
def _pairwise_distances(self, inputs: List[Tensor]) -> Tensor:
emb_c, emb_r = inputs
bs = K.shape(emb_c)[0]
embeddings = K.concatenate([emb_c, emb_r], 0)
dot_product = K.dot(embeddings, K.transpose(embeddings))
square_norm = K.batch_dot(embeddings, embeddings, axes=1)
distances = K.transpose(square_norm) - 2.0 * dot_product + square_norm
distances = K.slice(distances, (0, bs), (bs, bs))
distances = K.clip(distances, 0.0, None)
mask = K.cast(K.equal(distances, 0.0), K.dtype(distances))
distances = distances + mask * 1e-16
distances = K.sqrt(distances)
distances = distances * (1.0 - mask)
return distances

    def _triplet_loss(self, labels: Tensor, pairwise_dist: Tensor) -> Tensor:
        """Triplet loss function"""
        y_true = K.squeeze(labels, axis=1)
        if self.hard_triplets:
            triplet_loss = self._batch_hard_triplet_loss(y_true, pairwise_dist)
        else:
            triplet_loss = self._batch_all_triplet_loss(y_true, pairwise_dist)
        return triplet_loss
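
    # "Batch all" strategy: broadcast the distance matrix into a
    # (batch, batch, batch) tensor of d(a, p) - d(a, n) + margin over every
    # (anchor, positive, negative) combination, zero out invalid triplets with
    # a mask and average the loss over the triplets that still violate the
    # margin.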
def _batch_all_triplet_loss(self, y_true: Tensor, pairwise_dist: Tensor) -> Tensor:
anchor_positive_dist = K.expand_dims(pairwise_dist, 2)
anchor_negative_dist = K.expand_dims(pairwise_dist, 1)
triplet_loss = anchor_positive_dist - anchor_negative_dist + self.margin
mask = self._get_triplet_mask(y_true, pairwise_dist)
triplet_loss = mask * triplet_loss
triplet_loss = K.clip(triplet_loss, 0.0, None)
valid_triplets = K.cast(K.greater(triplet_loss, 1e-16), K.dtype(triplet_loss))
num_positive_triplets = K.sum(valid_triplets)
triplet_loss = K.sum(triplet_loss) / (num_positive_triplets + 1e-16)
return triplet_loss
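
    # "Batch hard" strategy: for each anchor take the farthest valid positive
    # and the closest (semi-hard) negative. Adding the row-wise maximum
    # distance to masked-out entries guarantees they never win the K.min below.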
def _batch_hard_triplet_loss(self, y_true: Tensor, pairwise_dist: Tensor) -> Tensor:
mask_anchor_positive = self._get_anchor_positive_triplet_mask(y_true, pairwise_dist)
anchor_positive_dist = mask_anchor_positive * pairwise_dist
hardest_positive_dist = K.max(anchor_positive_dist, axis=1, keepdims=True)
mask_anchor_negative = self._get_anchor_negative_triplet_mask(y_true, pairwise_dist)
anchor_negative_dist = mask_anchor_negative * pairwise_dist
mask_anchor_negative = self._get_semihard_anchor_negative_triplet_mask(anchor_negative_dist,
hardest_positive_dist,
mask_anchor_negative)
max_anchor_negative_dist = K.max(pairwise_dist, axis=1, keepdims=True)
anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative)
hardest_negative_dist = K.min(anchor_negative_dist, axis=1, keepdims=True)
triplet_loss = K.clip(hardest_positive_dist - hardest_negative_dist + self.margin, 0.0, None)
triplet_loss = K.mean(triplet_loss)
return triplet_loss
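
    # The helpers below build multiplicative 0/1 masks over label pairs; a
    # pairwise distance of exactly zero is treated as an anchor compared with
    # itself and is excluded.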
    def _get_triplet_mask(self, y_true: Tensor, pairwise_dist: Tensor) -> Tensor:
        # keep (a, p) pairs where label(a) == label(p)
        mask1 = K.expand_dims(K.equal(K.expand_dims(y_true, 0), K.expand_dims(y_true, 1)), 2)
        mask1 = K.cast(mask1, K.dtype(pairwise_dist))
        # exclude pairs with a == p (zero distance)
        mask2 = K.expand_dims(K.not_equal(pairwise_dist, 0.0), 2)
        mask2 = K.cast(mask2, K.dtype(pairwise_dist))
        # keep (a, n) pairs where label(n) != label(a)
        mask3 = K.expand_dims(K.not_equal(K.expand_dims(y_true, 0), K.expand_dims(y_true, 1)), 1)
        mask3 = K.cast(mask3, K.dtype(pairwise_dist))
        return mask1 * mask2 * mask3

    def _get_anchor_positive_triplet_mask(self, y_true: Tensor, pairwise_dist: Tensor) -> Tensor:
        # keep (a, p) pairs where label(a) == label(p)
        mask1 = K.equal(K.expand_dims(y_true, 0), K.expand_dims(y_true, 1))
        mask1 = K.cast(mask1, K.dtype(pairwise_dist))
        # exclude pairs with a == p (zero distance)
        mask2 = K.not_equal(pairwise_dist, 0.0)
        mask2 = K.cast(mask2, K.dtype(pairwise_dist))
        return mask1 * mask2

    def _get_anchor_negative_triplet_mask(self, y_true: Tensor, pairwise_dist: Tensor) -> Tensor:
        # keep (a, n) pairs where label(n) != label(a)
        mask = K.not_equal(K.expand_dims(y_true, 0), K.expand_dims(y_true, 1))
        mask = K.cast(mask, K.dtype(pairwise_dist))
        return mask

    def _get_semihard_anchor_negative_triplet_mask(self, negative_dist: Tensor,
                                                   hardest_positive_dist: Tensor,
                                                   mask_negative: Tensor) -> Tensor:
        # semi-hard negatives: dist(a, n) > max(dist(a, p))
        mask = K.greater(negative_dist, hardest_positive_dist)
        mask = K.cast(mask, K.dtype(negative_dist))
        # rows that contain at least one semi-hard negative
        mask_semihard = K.cast(K.expand_dims(K.greater(K.sum(mask, 1), 0.0), 1), K.dtype(negative_dist))
        # fall back to all valid negatives for rows without semi-hard ones
        mask = mask_negative * (1 - mask_semihard) + mask * mask_semihard
        return mask

    def _predict_on_batch(self, batch: List[np.ndarray]) -> np.ndarray:
return self.score_model.predict_on_batch(x=batch)