Source code for deeppavlov.models.ranking.ranking_model

# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from overrides import overrides
from copy import deepcopy
import inspect
import numpy as np
import random
from nltk.tokenize import sent_tokenize, word_tokenize

from deeppavlov.core.common.attributes import check_attr_true
from deeppavlov.core.common.registry import register
from deeppavlov.core.models.nn_model import NNModel
from deeppavlov.models.ranking.ranking_network import RankingNetwork
from deeppavlov.models.ranking.insurance_dict import InsuranceDict
from deeppavlov.models.ranking.emb_dict import EmbDict
from deeppavlov.core.common.log import get_logger
from typing import Union, List, Tuple, Dict

log = get_logger(__name__)


@register('ranking_model')
class RankingModel(NNModel):
    """Class to perform ranking.

    Args:
        vocab_name: A keyword that indicates which subclass of the
            :class:`deeppavlov.models.ranking.ranking_dict.RankingDict` to use.
        hard_triplets_sampling: Whether to use hard triplets sampling to train the model,
            i.e., to choose negative samples close to positive ones.
        hardest_positives: Whether to use only the single hardest positive sample per anchor sample.
        semi_hard_negatives: Whether hard negative samples should be further away from
            anchor samples than positive samples.
        num_hardest_negatives: If set, the negative for each anchor is sampled from
            this number of its hardest negatives instead of always taking the single hardest one.
        update_embeddings: Whether to store and update context and response embeddings.
        interact_pred_num: The number of the most relevant contexts and responses
            which the model returns in the `interact` regime.
        **kwargs: Other parameters.
    """

    def __init__(self,
                 vocab_name,
                 hard_triplets_sampling: bool = False,
                 hardest_positives: bool = False,
                 semi_hard_negatives: bool = False,
                 num_hardest_negatives: int = None,
                 update_embeddings: bool = False,
                 interact_pred_num: int = 3,
                 **kwargs):

        # Parameters for the parent classes.
        save_path = kwargs.get('save_path', None)
        load_path = kwargs.get('load_path', None)
        train_now = kwargs.get('train_now', None)
        mode = kwargs.get('mode', None)

        super().__init__(save_path=save_path,
                         load_path=load_path,
                         train_now=train_now,
                         mode=mode)

        self.hard_triplets_sampling = hard_triplets_sampling
        self.hardest_positives = hardest_positives
        self.semi_hard_negatives = semi_hard_negatives
        self.num_hardest_negatives = num_hardest_negatives
        self.upd_embs = update_embeddings
        self.interact_pred_num = interact_pred_num
        self.train_now = train_now
        self.vocab_name = vocab_name

        opt = deepcopy(kwargs)

        # Pass each component only the keyword arguments its signature accepts.
        if self.vocab_name == "insurance":
            dict_parameter_names = list(inspect.signature(InsuranceDict.__init__).parameters)
            dict_parameters = {par: opt[par] for par in dict_parameter_names if par in opt}
            self.dict = InsuranceDict(**dict_parameters, update_embeddings=update_embeddings)

        embdict_parameter_names = list(inspect.signature(EmbDict.__init__).parameters)
        embdict_parameters = {par: opt[par] for par in embdict_parameter_names if par in opt}
        self.embdict = EmbDict(**embdict_parameters)

        network_parameter_names = list(inspect.signature(RankingNetwork.__init__).parameters)
        self.network_parameters = {par: opt[par] for par in network_parameter_names if par in opt}

        self.load()

        train_parameters_names = list(inspect.signature(self._net.train_on_batch).parameters)
        self.train_parameters = {par: opt[par] for par in train_parameters_names if par in opt}

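    # A minimal sketch (not part of the original class) of the
    # signature-filtering idiom used in __init__ above: only the keyword
    # arguments a component actually accepts are taken from the shared
    # config dict. Names here (`f`, `opt`) are hypothetical.
    #
    #     >>> import inspect
    #     >>> def f(a, b=1):
    #     ...     return a + b
    #     >>> opt = {"a": 2, "b": 3, "unused": 0}
    #     >>> {par: opt[par] for par in inspect.signature(f).parameters if par in opt}
    #     {'a': 2, 'b': 3}
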
    @overrides
    def load(self):
        """Load the model from the last checkpoint."""
        init_from_scratch = not self.load_path.exists()
        if init_from_scratch:
            log.info("[initializing new `{}`]".format(self.__class__.__name__))
            self.dict.init_from_scratch()
            self.embdict.init_from_scratch(self.dict.tok2int_vocab)
        else:
            log.info("[initializing `{}` from saved]".format(self.__class__.__name__))
            self.dict.load()
            self.embdict.load()
        # The character vocabulary is optional: not every dict subclass provides one.
        if hasattr(self.dict, 'char2int_vocab'):
            chars_num = len(self.dict.char2int_vocab)
        else:
            chars_num = 0
        self._net = RankingNetwork(chars_num=chars_num,
                                   toks_num=len(self.dict.tok2int_vocab),
                                   emb_dict=self.embdict,
                                   **self.network_parameters)
        if init_from_scratch:
            self._net.init_from_scratch(self.embdict.emb_matrix)
        else:
            self._net.load(self.load_path)

    @overrides
    def save(self):
        """Save the model."""
        log.info('[saving model to {}]'.format(self.save_path.resolve()))
        self._net.save(self.save_path)
        if self.upd_embs:
            self.set_embeddings()
        self.dict.save()
        self.embdict.save()

    @check_attr_true('train_now')
    def train_on_batch(self, x: List[List[Tuple[int, int]]], y: List[int]):
        """Train the model on a batch."""
        if self.upd_embs:
            self.reset_embeddings()
        if self.hard_triplets_sampling:
            b = self.make_hard_triplets(x, y, self._net)
            y = np.ones(len(b[0][0]))
        else:
            b = self.make_batch(x)
            for i in range(len(x[0])):
                c = self.dict.make_toks(b[i][0], type="context")
                c = self.dict.make_ints(c)
                b[i][0] = c
                r = self.dict.make_toks(b[i][1], type="response")
                r = self.dict.make_ints(r)
                b[i][1] = r
        self._net.train_on_batch(b, y)

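    # Shape note for train_on_batch (illustrative, hypothetical toy values):
    # each element of `x` is a list of (context id, response id) pairs and
    # `y` holds one label per batch element, e.g.
    #
    #     x = [[(0, 5), (0, 7)], [(1, 2), (1, 9)]]
    #     y = [0, 1]
    #
    # make_batch below transposes this into per-position [contexts, responses]
    # lists before the ids are mapped to token-index matrices.
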
    def make_batch(self, x):
        # Transpose a batch of samples into per-position [contexts, responses] lists.
        sample_len = len(x[0])
        b = []
        for i in range(sample_len):
            c = []
            r = []
            for el in x:
                c.append(el[i][0])
                r.append(el[i][1])
            b.append([c, r])
        return b

    def make_hard_triplets(self, x, y, net):
        # Build (anchor, positive, negative) triplets using distances between
        # L2-normalized embeddings predicted by the current network.
        samples = [[s[1] for s in el] for el in x]
        labels = y
        batch_size = len(samples)
        num_samples = len(samples[0])

        samp = [smp for el in samples for smp in el]
        s = self.dict.make_toks(samp, type="context")
        s = self.dict.make_ints(s)

        embeddings = net.predict_embedding([s, s], 512, type='context')
        embeddings = embeddings / np.expand_dims(np.linalg.norm(embeddings, axis=1), axis=1)

        # Pairwise Euclidean distances via ||a - b||^2 = ||a||^2 - 2ab + ||b||^2.
        dot_product = embeddings @ embeddings.T
        square_norm = np.diag(dot_product)
        distances = np.expand_dims(square_norm, 0) - 2.0 * dot_product + np.expand_dims(square_norm, 1)
        distances = np.maximum(distances, 0.0)
        distances = np.sqrt(distances)

        mask_anchor_negative = np.expand_dims(np.repeat(labels, num_samples), 0) \
                               != np.expand_dims(np.repeat(labels, num_samples), 1)
        mask_anchor_negative = mask_anchor_negative.astype(float)
        # Shift same-label distances up so that argmin picks a true negative.
        max_anchor_negative_dist = np.max(distances, axis=1, keepdims=True)
        anchor_negative_dist = distances + max_anchor_negative_dist * (1.0 - mask_anchor_negative)
        if self.num_hardest_negatives is not None:
            # Sample the negative uniformly among the hardest ones.
            hard = np.argsort(anchor_negative_dist, axis=1)[:, :self.num_hardest_negatives]
            ind = np.random.randint(self.num_hardest_negatives, size=batch_size * num_samples)
            hardest_negative_ind = hard[batch_size * num_samples * [True], ind]
        else:
            hardest_negative_ind = np.argmin(anchor_negative_dist, axis=1)

        mask_anchor_positive = np.expand_dims(np.repeat(labels, num_samples), 0) \
                               == np.expand_dims(np.repeat(labels, num_samples), 1)
        mask_anchor_positive = mask_anchor_positive.astype(float)
        anchor_positive_dist = mask_anchor_positive * distances

        c = []
        rp = []
        rn = []
        hrds = []
        if self.hardest_positives:
            if self.semi_hard_negatives:
                hardest_positive_ind = []
                hardest_negative_ind = []
                for p, n in zip(anchor_positive_dist, anchor_negative_dist):
                    no_samples = True
                    p_li = list(zip(p, np.arange(batch_size * num_samples), batch_size * num_samples * [True]))
                    n_li = list(zip(n, np.arange(batch_size * num_samples), batch_size * num_samples * [False]))
                    pn_li = sorted(p_li + n_li, key=lambda el: el[0])
                    # Find the closest negative that is still further away than some positive.
                    for idx, cand in enumerate(pn_li):
                        if not cand[2]:
                            for prev in pn_li[:idx][::-1]:
                                if prev[2] and prev[0] > 0.0:
                                    assert cand[1] != prev[1]
                                    hardest_negative_ind.append(cand[1])
                                    hardest_positive_ind.append(prev[1])
                                    no_samples = False
                                    break
                            if not no_samples:
                                break
                    if no_samples:
                        print("There are no negative examples with distances greater than positive example distances.")
                        exit(0)
            else:
                if self.num_hardest_negatives is not None:
                    hard = np.argsort(anchor_positive_dist, axis=1)[:, -self.num_hardest_negatives:]
                    ind = np.random.randint(self.num_hardest_negatives, size=batch_size * num_samples)
                    hardest_positive_ind = hard[batch_size * num_samples * [True], ind]
                else:
                    hardest_positive_ind = np.argmax(anchor_positive_dist, axis=1)
            for i in range(batch_size):
                for j in range(num_samples):
                    c.append(s[i * num_samples + j])
                    rp.append(s[hardest_positive_ind[i * num_samples + j]])
                    rn.append(s[hardest_negative_ind[i * num_samples + j]])
        else:
            if self.semi_hard_negatives:
                for i in range(batch_size):
                    for j in range(num_samples):
                        for k in range(j + 1, num_samples):
                            c.append(s[i * num_samples + j])
                            c.append(s[i * num_samples + k])
                            rp.append(s[i * num_samples + k])
                            rp.append(s[i * num_samples + j])
                            n, hrd = self.get_semi_hard_negative_ind(i, j, k, distances, anchor_negative_dist,
                                                                     batch_size, num_samples)
                            assert n != i * num_samples + k
                            rn.append(s[n])
                            hrds.append(hrd)
                            n, hrd = self.get_semi_hard_negative_ind(i, k, j, distances, anchor_negative_dist,
                                                                     batch_size, num_samples)
                            assert n != i * num_samples + j
                            rn.append(s[n])
                            hrds.append(hrd)
            else:
                for i in range(batch_size):
                    for j in range(num_samples):
                        for k in range(j + 1, num_samples):
                            c.append(s[i * num_samples + j])
                            c.append(s[i * num_samples + k])
                            rp.append(s[i * num_samples + k])
                            rp.append(s[i * num_samples + j])
                            rn.append(s[hardest_negative_ind[i * num_samples + j]])
                            rn.append(s[hardest_negative_ind[i * num_samples + k]])

        triplets = list(zip(c, rp, rn))
        np.random.shuffle(triplets)
        c = [el[0] for el in triplets]
        rp = [el[1] for el in triplets]
        rn = [el[2] for el in triplets]
        if hrds:
            # `hrds` is only filled on the semi-hard path; guard against an empty list.
            ratio = sum(hrds) / len(hrds)
            print("Ratio of semi-hard negative samples is %f" % ratio)
        return [(c, rp), (c, rn)]

    def get_semi_hard_negative_ind(self, i, j, k, distances, anchor_negative_dist, batch_size, num_samples):
        # Choose the closest negative that is further from the anchor than the
        # positive; fall back to a random negative if none exists.
        anc_pos_dist = distances[i * num_samples + j, i * num_samples + k]
        neg_dists = anchor_negative_dist[i * num_samples + j]
        n_li_pre = sorted(list(zip(neg_dists, np.arange(batch_size * num_samples))), key=lambda el: el[0])
        # Exclude samples from the anchor's own group.
        n_li = list(filter(lambda el: el[1] < i * num_samples, n_li_pre)) + \
               list(filter(lambda el: el[1] >= (i + 1) * num_samples, n_li_pre))
        for cand in n_li:
            if cand[0] > anc_pos_dist:
                return cand[1], True
        return random.choice(n_li)[1], False

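    # The pairwise distance matrix in make_hard_triplets relies on the
    # identity ||e_i - e_j||^2 = ||e_i||^2 - 2 e_i.e_j + ||e_j||^2 applied to
    # L2-normalized embeddings. A standalone numpy check (not part of the
    # model):
    #
    #     >>> emb = np.random.rand(4, 8)
    #     >>> emb /= np.linalg.norm(emb, axis=1, keepdims=True)
    #     >>> dot = emb @ emb.T
    #     >>> sq = np.diag(dot)
    #     >>> d = np.sqrt(np.maximum(sq[None, :] - 2.0 * dot + sq[:, None], 0.0))
    #     >>> direct = np.linalg.norm(emb[:, None] - emb[None, :], axis=2)
    #     >>> bool(np.allclose(d, direct, atol=1e-6))
    #     True
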
    def __call__(self, batch: Union[List[List[Tuple[int, int]]], List[str]]) -> \
            Union[np.ndarray, List[Dict[str, List[str]]]]:
        """Make a prediction on a batch."""
        if type(batch[0]) == list:
            # Ranking regime: return a score for each (context, response) pair.
            y_pred = []
            b = self.make_batch(batch)
            for el in b:
                c = self.dict.make_toks(el[0], type="context")
                c = self.dict.make_ints(c)
                r = self.dict.make_toks(el[1], type="response")
                r = self.dict.make_ints(r)
                yp = self._net.predict_score_on_batch([c, r])
                y_pred.append(yp)
            y_pred = np.hstack(y_pred)
            return y_pred
        elif type(batch[0]) == str:
            # Interact regime: rank stored contexts and responses by cosine
            # similarity to the embedded input.
            c_input = tokenize(batch)
            c_input = self.dict.make_ints(c_input)
            c_input_emb = self._net.predict_embedding_on_batch([c_input, c_input], type='context')

            c_emb = [self.dict.context2emb_vocab[i] for i in range(len(self.dict.context2emb_vocab))]
            c_emb = np.vstack(c_emb)
            pred_cont = np.sum(c_input_emb * c_emb, axis=1) \
                        / np.linalg.norm(c_input_emb, axis=1) / np.linalg.norm(c_emb, axis=1)
            pred_cont = np.flip(np.argsort(pred_cont), 0)[:self.interact_pred_num]
            pred_cont = [' '.join(self.dict.context2toks_vocab[el]) for el in pred_cont]

            r_emb = [self.dict.response2emb_vocab[i] for i in range(len(self.dict.response2emb_vocab))]
            r_emb = np.vstack(r_emb)
            pred_resp = np.sum(c_input_emb * r_emb, axis=1) \
                        / np.linalg.norm(c_input_emb, axis=1) / np.linalg.norm(r_emb, axis=1)
            pred_resp = np.flip(np.argsort(pred_resp), 0)[:self.interact_pred_num]
            pred_resp = [' '.join(self.dict.response2toks_vocab[el]) for el in pred_resp]

            y_pred = [{"contexts": pred_cont, "responses": pred_resp}]
            return y_pred

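    # __call__ thus has two output regimes: an array with one score per
    # (context, response) pair for list inputs, and for string inputs a
    # single-element list like (illustrative values only):
    #
    #     [{"contexts": ["top context 1", ...],
    #       "responses": ["top response 1", ...]}]
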
    def update_embeddings(self, batch):
        # Recompute and store embeddings for the contexts and responses seen in `batch`.
        sample_len = len(batch[0])
        labels_cont = []
        labels_resp = []
        cont = []
        resp = []
        for i in range(sample_len):
            lc = []
            lr = []
            for el in batch:
                lc.append(el[i][0])
                lr.append(el[i][1])
            labels_cont.append(lc)
            labels_resp.append(lr)
        for i in range(sample_len):
            c = self.dict.make_toks(labels_cont[i], type="context")
            c = self.dict.make_ints(c)
            cont.append(c)
            r = self.dict.make_toks(labels_resp[i], type="response")
            r = self.dict.make_ints(r)
            resp.append(r)
        for el in zip(labels_cont, cont):
            c_emb = self._net.predict_embedding_on_batch([el[1], el[1]], type='context')
            for i in range(len(el[0])):
                self.dict.context2emb_vocab[el[0][i]] = c_emb[i]
        for el in zip(labels_resp, resp):
            r_emb = self._net.predict_embedding_on_batch([el[1], el[1]], type='response')
            for i in range(len(el[0])):
                self.dict.response2emb_vocab[el[0][i]] = r_emb[i]

    def set_embeddings(self):
        # Fill the id -> embedding vocabularies if they have not been computed yet.
        if self.dict.response2emb_vocab[0] is None:
            r = []
            for i in range(len(self.dict.response2toks_vocab)):
                r.append(self.dict.response2toks_vocab[i])
            r = self.dict.make_ints(r)
            response_embeddings = self._net.predict_embedding([r, r], 512, type='response')
            for i in range(len(self.dict.response2toks_vocab)):
                self.dict.response2emb_vocab[i] = response_embeddings[i]
        if self.dict.context2emb_vocab[0] is None:
            c = []
            for i in range(len(self.dict.context2toks_vocab)):
                c.append(self.dict.context2toks_vocab[i])
            c = self.dict.make_ints(c)
            context_embeddings = self._net.predict_embedding([c, c], 512, type='context')
            for i in range(len(self.dict.context2toks_vocab)):
                self.dict.context2emb_vocab[i] = context_embeddings[i]

    def reset_embeddings(self):
        # Clear stored embeddings so they are recomputed after the next update.
        if self.dict.response2emb_vocab[0] is not None:
            for i in range(len(self.dict.response2emb_vocab)):
                self.dict.response2emb_vocab[i] = None
        if self.dict.context2emb_vocab[0] is not None:
            for i in range(len(self.dict.context2emb_vocab)):
                self.dict.context2emb_vocab[i] = None

    def shutdown(self):
        pass

    def reset(self):
        pass


def tokenize(sen_list):
    sen_tokens_list = []
    for sen in sen_list:
        sent_toks = sent_tokenize(sen)
        word_toks = [word_tokenize(el) for el in sent_toks]
        tokens = [val for sublist in word_toks for val in sublist]
        tokens = [el for el in tokens if el != '']
        sen_tokens_list.append(tokens)
    return sen_tokens_list
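
# A minimal usage sketch for `tokenize` (not part of the original module); it
# assumes the package is importable and the NLTK 'punkt' models are available,
# e.g. after nltk.download('punkt'):
if __name__ == '__main__':
    print(tokenize(["How are you? I am fine."]))
    # -> [['How', 'are', 'you', '?', 'I', 'am', 'fine', '.']]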