Source code for deeppavlov.models.seq2seq_go_bot.bot

# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools

from deeppavlov.core.common.registry import register
from deeppavlov.core.models.component import Component
from deeppavlov.core.models.nn_model import NNModel
from deeppavlov.core.common.log import get_logger


log = get_logger(__name__)


@register("seq2seq_go_bot")
class Seq2SeqGoalOrientedBot(NNModel):
    """
    A goal-oriented bot based on a sequence-to-sequence RNN. For implementation details see
    :class:`~deeppavlov.models.seq2seq_go_bot.network.Seq2SeqGoalOrientedBotNetwork`.
    Pretrained for the :class:`~deeppavlov.dataset_readers.kvret_reader.KvretDatasetReader` dataset.

    Parameters:
        network: object of
            :class:`~deeppavlov.models.seq2seq_go_bot.network.Seq2SeqGoalOrientedBotNetwork` class.
        source_vocab: vocabulary of input tokens.
        target_vocab: vocabulary of bot response tokens.
        start_of_sequence_token: token that defines start of input sequence.
        end_of_sequence_token: token that defines end of input sequence and start of output sequence.
        debug: whether to display debug output.
        **kwargs: parameters passed to parent
            :class:`~deeppavlov.core.models.nn_model.NNModel` class.
    """

    def __init__(self,
                 network: Component,
                 source_vocab: Component,
                 target_vocab: Component,
                 start_of_sequence_token: str,
                 end_of_sequence_token: str,
                 debug: bool = False,
                 save_path: str = None,
                 **kwargs) -> None:
        super().__init__(save_path=save_path, **kwargs)

        self.network = network
        self.src_vocab = source_vocab
        self.tgt_vocab = target_vocab
        self.sos_token = start_of_sequence_token
        self.eos_token = end_of_sequence_token
        self.debug = debug

    def train_on_batch(self, *batch):
        b_enc_ins, b_src_lens = [], []
        b_dec_ins, b_dec_outs, b_tgt_lens, b_tgt_weights = [], [], [], []
        for x_tokens, dialog_id, y_tokens in zip(*batch):
            # encoder input: indexed context tokens
            enc_in = self._encode_context(x_tokens, dialog_id)
            b_enc_ins.append(enc_in)
            b_src_lens.append(len(enc_in))

            # decoder input starts with <sos>, decoder output ends with <eos>
            dec_in, dec_out = self._encode_response(y_tokens)
            b_dec_ins.append(dec_in)
            b_dec_outs.append(dec_out)
            b_tgt_lens.append(len(dec_out))
            b_tgt_weights.append([1] * len(dec_out))

        # Sequence padding (padded target positions get zero loss weight)
        max_src_len = max(b_src_lens)
        max_tgt_len = max(b_tgt_lens)
        for i, (src_len, tgt_len) in enumerate(zip(b_src_lens, b_tgt_lens)):
            src_padd_len = max_src_len - src_len
            tgt_padd_len = max_tgt_len - tgt_len
            b_enc_ins[i].extend([self.src_vocab[self.sos_token]] * src_padd_len)
            b_dec_ins[i].extend([self.tgt_vocab[self.eos_token]] * tgt_padd_len)
            b_dec_outs[i].extend([self.tgt_vocab[self.eos_token]] * tgt_padd_len)
            b_tgt_weights[i].extend([0] * tgt_padd_len)

        self.network.train_on_batch(b_enc_ins, b_dec_ins, b_dec_outs,
                                    b_src_lens, b_tgt_lens, b_tgt_weights)

    def _encode_context(self, tokens, dialog_id=None):
        if self.debug:
            log.debug("Context tokens = \"{}\"".format(tokens))
        token_idxs = self.src_vocab(tokens)
        return token_idxs

    def _encode_response(self, tokens):
        if self.debug:
            log.debug("Response tokens = \"{}\"".format(tokens))
        token_idxs = self.tgt_vocab(tokens)
        return ([self.tgt_vocab[self.sos_token]] + token_idxs,
                token_idxs + [self.tgt_vocab[self.eos_token]])

    def __call__(self, *batch):
        return self._infer_on_batch(*batch)

    def _infer_on_batch(self, utters, dialog_ids=itertools.repeat(None)):
        def _filter(tokens):
            # yield tokens up to (but not including) the end-of-sequence token
            for t in tokens:
                if t == self.eos_token:
                    break
                yield t

        # TODO: history as input
        b_enc_ins, b_src_lens = [], []
        if (len(utters) == 1) and not utters[0]:
            utters = [['hi']]
        for utter, dialog_id in zip(utters, dialog_ids):
            enc_in = self._encode_context(utter, dialog_id)
            b_enc_ins.append(enc_in)
            b_src_lens.append(len(enc_in))

        # Sequence padding
        max_src_len = max(b_src_lens)
        for i, src_len in enumerate(b_src_lens):
            src_padd_len = max_src_len - src_len
            b_enc_ins[i].extend([self.src_vocab[self.eos_token]] * src_padd_len)

        pred_idxs = self.network(b_enc_ins, b_src_lens)
        preds = [list(_filter(self.tgt_vocab(utter_idxs)))
                 for utter_idxs in pred_idxs]
        if self.debug:
            print("Dialog prediction = \"{}\"".format(preds[-1]))
        return preds

    def save(self):
        self.network.save()

    def load(self):
        pass
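
Below is a minimal usage sketch of the interface this class expects from its components. ToyVocab and EchoNetwork are hypothetical stand-ins written only for illustration (they are not part of DeepPavlov); in practice source_vocab, target_vocab and network are built by the DeepPavlov pipeline from a config, and this sketch assumes the bot can be instantiated directly with save_path=None outside a pipeline.

class ToyVocab:
    """Toy vocabulary: maps tokens to ids via [] and converts whole sequences via ()."""
    def __init__(self, tokens):
        self._t2i = {t: i for i, t in enumerate(tokens)}
        self._i2t = {i: t for t, i in self._t2i.items()}

    def __getitem__(self, token):
        return self._t2i[token]

    def __call__(self, seq):
        # token sequence -> ids, or id sequence -> tokens
        if seq and isinstance(seq[0], str):
            return [self._t2i[t] for t in seq]
        return [self._i2t[i] for i in seq]


class EchoNetwork:
    """Stub network that 'predicts' its input ids back, just to exercise the bot API."""
    def train_on_batch(self, enc_ins, dec_ins, dec_outs, src_lens, tgt_lens, tgt_weights):
        pass

    def __call__(self, enc_ins, src_lens):
        return enc_ins

    def save(self):
        pass


if __name__ == '__main__':
    vocab = ToyVocab(['<sos>', '<eos>', 'hi', 'hello', 'there'])
    bot = Seq2SeqGoalOrientedBot(network=EchoNetwork(),
                                 source_vocab=vocab,
                                 target_vocab=vocab,
                                 start_of_sequence_token='<sos>',
                                 end_of_sequence_token='<eos>',
                                 debug=True)
    # train_on_batch takes parallel lists: context tokens, dialog ids, response tokens
    bot.train_on_batch([['hi', 'there']], [0], [['hello']])
    # inference takes a batch of tokenized utterances; with the echo stub the
    # "prediction" is simply the input tokens, e.g. [['hi', 'there']]
    print(bot([['hi', 'there']]))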