Source code for deeppavlov.models.go_bot.bot

# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Dict, Any
import numpy as np

from deeppavlov.core.commands.utils import expand_path
from deeppavlov.core.common.registry import register
from deeppavlov.core.models.component import Component
from deeppavlov.core.models.nn_model import NNModel
from deeppavlov.core.common.log import get_logger
from deeppavlov.models.go_bot.tracker import Tracker
from deeppavlov.models.go_bot.network import GoalOrientedBotNetwork
import deeppavlov.models.go_bot.templates as templ


log = get_logger(__name__)


@register("go_bot")
class GoalOrientedBot(NNModel):
    """
    The dialogue bot is based on https://arxiv.org/abs/1702.03274, which
    introduces Hybrid Code Networks that combine an RNN with domain-specific
    knowledge and system action templates.

    Todo:
        add docstring for trackers.

    Parameters:
        tokenizer: one of tokenizers from
            :doc:`deeppavlov.models.tokenizers </apiref/models/tokenizers>` module.
        tracker: dialogue state tracker from
            :doc:`deeppavlov.models.go_bot.tracker </apiref/models/go_bot>`.
        network_parameters: initialization parameters for policy network
            (see :class:`~deeppavlov.models.go_bot.network.GoalOrientedBotNetwork`).
        template_path: file with mapping between actions and text templates
            for response generation.
        template_type: type of used response templates in string format.
        word_vocab: vocabulary of input word tokens
            (:class:`~deeppavlov.core.data.vocab.DefaultVocabulary` recommended).
        bow_embedder: instance of one-hot word encoder
            :class:`~deeppavlov.models.embedders.bow_embedder.BoWEmbedder`.
        embedder: one of embedders from
            :doc:`deeppavlov.models.embedders </apiref/models/embedders>` module.
        slot_filler: component that outputs slot values for a given utterance
            (:class:`~deeppavlov.models.slotfill.slotfill.DstcSlotFillingNetwork`
            recommended).
        intent_classifier: component that outputs intents probability distribution
            for a given utterance
            (:class:`~deeppavlov.models.classifiers.keras_classification_model.KerasClassificationModel`
            recommended).
        database: database that will be used during inference to perform
            ``api_call_action`` actions and get ``'db_result'`` result
            (:class:`~deeppavlov.core.data.sqlite_database.Sqlite3Database` recommended).
        api_call_action: label of the action that corresponds to database api call
            (it must be present in your ``template_path`` file), during interaction
            it will be used to get ``'db_result'`` from ``database``.
        use_action_mask: if ``True``, network output will be applied with a mask
            over allowed actions.
        debug: whether to display debug output.
    """
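    # Illustrative note (not part of the original source): for every utterance the
    # policy network consumes a single feature vector that concatenates, in order,
    # bag-of-words features, the utterance embedding, intent probabilities, tracker
    # state features, 6 hand-crafted context features and a one-hot of the previous
    # action (see `_encode_context`). `_init_network` sizes the network accordingly;
    # with hypothetical component sizes the arithmetic looks like this:
    #
    #     n_actions, num_tracker_feats = 45, 10          # hypothetical values
    #     vocab_size, emb_dim, n_intents = 500, 100, 14  # hypothetical values
    #     obs_size = 6 + num_tracker_feats + n_actions   # always present
    #     obs_size += vocab_size                         # only if bow_embedder is set
    #     obs_size += emb_dim                            # only if embedder is set
    #     obs_size += n_intents                          # only if intent_classifier is set
    #     assert obs_size == 675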
""" def __init__(self, tokenizer: Component, tracker: Tracker, network_parameters: Dict[str, Any], template_path: str, save_path: str, load_path: str = None, template_type: str = "DefaultTemplate", word_vocab: Component = None, bow_embedder: Component = None, embedder: Component = None, slot_filler: Component = None, intent_classifier: Component = None, database: Component = None, api_call_action: str = None, # TODO: make it unrequired use_action_mask: bool = False, debug: bool = False, **kwargs): super().__init__(load_path=load_path, save_path=save_path, **kwargs) self.tokenizer = tokenizer self.tracker = tracker self.bow_embedder = bow_embedder self.embedder = embedder self.slot_filler = slot_filler self.intent_classifier = intent_classifier self.use_action_mask = use_action_mask self.debug = debug self.word_vocab = word_vocab template_path = expand_path(template_path) template_type = getattr(templ, template_type) log.info("[loading templates from {}]".format(template_path)) self.templates = templ.Templates(template_type).load(template_path) self.n_actions = len(self.templates) log.info("{} templates loaded".format(self.n_actions)) self.database = database self.api_call_id = None if api_call_action is not None: self.api_call_id = self.templates.actions.index(api_call_action) self.intents = [] if callable(self.intent_classifier): # intent_classifier returns y_probas self.intents = self.intent_classifier.get_main_component().classes network_parameters['load_path'] = load_path network_parameters['save_path'] = save_path self.network = self._init_network(network_parameters) self.reset() def _init_network(self, params): # initialize network obs_size = 6 + self.tracker.num_features + self.n_actions if callable(self.bow_embedder): obs_size += len(self.word_vocab) if callable(self.embedder): obs_size += self.embedder.dim if callable(self.intent_classifier): obs_size += len(self.intents) log.info("Calculated input size for `GoalOrientedBotNetwork` is {}" .format(obs_size)) if 'obs_size' not in params: params['obs_size'] = obs_size if 'action_size' not in params: params['action_size'] = self.n_actions attn = params.get('attention_mechanism') if attn: attn['token_size'] = attn.get('token_size') or self.embedder.dim attn['action_as_key'] = attn.get('action_as_key', False) attn['intent_as_key'] = attn.get('intent_as_key', False) key_size = 0 if attn['action_as_key']: key_size += self.n_actions if attn['intent_as_key'] and callable(self.intent_classifier): key_size += len(self.intents) key_size = key_size or 1 attn['key_size'] = attn.get('key_size') or key_size params['attention_mechanism'] = attn return GoalOrientedBotNetwork(**params) def _encode_context(self, context, db_result=None): # tokenize input tokens = self.tokenizer([context.lower().strip()])[0] if self.debug: log.debug("Tokenized text= `{}`".format(' '.join(tokens))) # Bag of words features bow_features = [] if callable(self.bow_embedder): tokens_idx = self.word_vocab(tokens) bow_features = self.bow_embedder([tokens_idx])[0] bow_features = bow_features.astype(np.float32) # Embeddings emb_features = [] emb_context = np.array([], dtype=np.float32) if callable(self.embedder): if self.network.attn: if tokens: pad = np.zeros((self.network.attn.max_num_tokens, self.network.attn.token_size), dtype=np.float32) sen = np.array(self.embedder([tokens])[0]) # TODO : Unsupport of batch_size more than 1 emb_context = np.concatenate((pad, sen)) emb_context = emb_context[-self.network.attn.max_num_tokens:] else: emb_context = \ 
    def _encode_context(self, context, db_result=None):
        # tokenize input
        tokens = self.tokenizer([context.lower().strip()])[0]
        if self.debug:
            log.debug("Tokenized text= `{}`".format(' '.join(tokens)))

        # Bag of words features
        bow_features = []
        if callable(self.bow_embedder):
            tokens_idx = self.word_vocab(tokens)
            bow_features = self.bow_embedder([tokens_idx])[0]
            bow_features = bow_features.astype(np.float32)

        # Embeddings
        emb_features = []
        emb_context = np.array([], dtype=np.float32)
        if callable(self.embedder):
            if self.network.attn:
                if tokens:
                    pad = np.zeros((self.network.attn.max_num_tokens,
                                    self.network.attn.token_size),
                                   dtype=np.float32)
                    sen = np.array(self.embedder([tokens])[0])
                    # TODO: batch_size greater than 1 is not supported
                    emb_context = np.concatenate((pad, sen))
                    emb_context = emb_context[-self.network.attn.max_num_tokens:]
                else:
                    emb_context = \
                        np.zeros((self.network.attn.max_num_tokens,
                                  self.network.attn.token_size),
                                 dtype=np.float32)
            else:
                emb_features = self.embedder([tokens], mean=True)[0]
                # random embedding instead of zeros
                if np.all(emb_features < 1e-20):
                    emb_dim = self.embedder.dim
                    emb_features = np.fabs(np.random.normal(0, 1/emb_dim, emb_dim))

        # Intent features
        intent_features = []
        if callable(self.intent_classifier):
            # intent, intent_probs = self.intent_classifier([context])
            # intent_features = np.array([intent_probs[0][i] for i in self.intents],
            #                            dtype=np.float32)
            intent_features = np.array(self.intent_classifier([context]))[0]
            intent = [self.intents[np.argmax(intent_features[0])]]
            if self.debug:
                log.debug("Predicted intent = `{}`".format(intent[0]))

        attn_key = np.array([], dtype=np.float32)
        if self.network.attn:
            if self.network.attn.action_as_key:
                attn_key = np.hstack((attn_key, self.prev_action))
            if self.network.attn.intent_as_key:
                attn_key = np.hstack((attn_key, intent_features))
            if len(attn_key) == 0:
                attn_key = np.array([1], dtype=np.float32)

        # Text entity features
        if callable(self.slot_filler):
            self.tracker.update_state(self.slot_filler([tokens])[0])
            if self.debug:
                log.debug("Slot vals: {}".format(self.slot_filler([tokens])))

        state_features = self.tracker.get_features()

        # Other features
        result_matches_state = 0.
        if self.db_result is not None:
            result_matches_state = all(v == self.db_result.get(s)
                                       for s, v in self.tracker.get_state().items()
                                       if v != 'dontcare') * 1.
        context_features = np.array([
            bool(db_result) * 1.,
            (db_result == {}) * 1.,
            (self.db_result is None) * 1.,
            bool(self.db_result) * 1.,
            (self.db_result == {}) * 1.,
            result_matches_state
        ], dtype=np.float32)
        if self.debug:
            log.debug("Context features = {}".format(context_features))
            debug_msg = "num bow features = {}, ".format(len(bow_features)) +\
                        "num emb features = {}, ".format(len(emb_features)) +\
                        "num intent features = {}, ".format(len(intent_features)) +\
                        "num state features = {}, ".format(len(state_features)) +\
                        "num context features = {}, ".format(len(context_features)) +\
                        "prev_action shape = {}".format(len(self.prev_action))
            log.debug(debug_msg)

        concat_feats = np.hstack((bow_features, emb_features, intent_features,
                                  state_features, context_features,
                                  self.prev_action))
        return concat_feats, emb_context, attn_key

    def _encode_response(self, act):
        return self.templates.actions.index(act)

    def _decode_response(self, action_id):
        """
        Convert action template id and entities from tracker
        to final response.
        """
        template = self.templates.templates[int(action_id)]

        slots = self.tracker.get_state()
        if self.db_result is not None:
            for k, v in self.db_result.items():
                slots[k] = str(v)

        resp = template.generate_text(slots)
        # in api calls replace unknown slots with "dontcare"
        if (self.templates.ttype is templ.DualTemplate) and\
                (action_id == self.api_call_id):
            resp = re.sub("#([A-Za-z]+)", "dontcare", resp).lower()
        if self.debug:
            log.debug("Pred response = {}".format(resp))
        return resp

    def _action_mask(self, previous_action):
        mask = np.ones(self.n_actions, dtype=np.float32)
        if self.use_action_mask:
            known_entities = {**self.tracker.get_state(), **(self.db_result or {})}
            for a_id in range(self.n_actions):
                tmpl = str(self.templates.templates[a_id])
                for entity in set(re.findall('#([A-Za-z]+)', tmpl)):
                    if entity not in known_entities:
                        mask[a_id] = 0.
        # forbid two api calls in a row
        if np.any(previous_action):
            prev_act_id = np.argmax(previous_action)
            if prev_act_id == self.api_call_id:
                mask[prev_act_id] = 0.
        return mask
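    # Illustration (not part of the original source) of `_action_mask` with
    # hypothetical templates and slots. With `use_action_mask=True`, a tracker
    # state of {'food': 'indian'} and no db_result yet:
    #
    #     "api_call food=#food area=#area"  ->  masked out ('area' is unknown)
    #     "#food restaurants in the #area"  ->  masked out ('area' is unknown)
    #     "what part of town do you want?"  ->  allowed (requires no #entities)
    #
    # Independently of `use_action_mask`, if the previous action was the api call
    # action, it is masked out so that two api calls cannot happen in a row.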
    def train_on_batch(self, x, y):
        b_features, b_u_masks, b_a_masks, b_actions = [], [], [], []
        b_emb_context, b_keys = [], []  # for attention
        max_num_utter = max(len(d_contexts) for d_contexts in x)
        for d_contexts, d_responses in zip(x, y):
            self.reset()
            if self.debug:
                preds = self._infer_dialog(d_contexts)
            d_features, d_a_masks, d_actions = [], [], []
            d_emb_context, d_key = [], []  # for attention
            for context, response in zip(d_contexts, d_responses):
                if context.get('db_result') is not None:
                    self.db_result = context['db_result']
                features, emb_context, key = \
                    self._encode_context(context['text'], context.get('db_result'))
                d_features.append(features)
                d_emb_context.append(emb_context)
                d_key.append(key)
                d_a_masks.append(self._action_mask(self.prev_action))

                action_id = self._encode_response(response['act'])
                d_actions.append(action_id)
                # previous action is teacher-forced here
                self.prev_action *= 0.
                self.prev_action[action_id] = 1.

                if self.debug:
                    log.debug("True response = `{}`".format(response['text']))
                    if preds[0].lower() != response['text'].lower():
                        log.debug("Pred response = `{}`".format(preds[0]))
                    preds = preds[1:]
                    if d_a_masks[-1][action_id] != 1.:
                        log.warn("True action forbidden by action mask.")

            # padding to max_num_utter
            num_padds = max_num_utter - len(d_contexts)
            d_features.extend([np.zeros_like(d_features[0])] * num_padds)
            d_emb_context.extend([np.zeros_like(d_emb_context[0])] * num_padds)
            d_key.extend([np.zeros_like(d_key[0])] * num_padds)
            d_u_mask = [1] * len(d_contexts) + [0] * num_padds
            d_a_masks.extend([np.zeros_like(d_a_masks[0])] * num_padds)
            d_actions.extend([0] * num_padds)

            b_features.append(d_features)
            b_emb_context.append(d_emb_context)
            b_keys.append(d_key)
            b_u_masks.append(d_u_mask)
            b_a_masks.append(d_a_masks)
            b_actions.append(d_actions)

        self.network.train_on_batch(b_features, b_emb_context, b_keys,
                                     b_u_masks, b_a_masks, b_actions)

    def _infer(self, context, db_result=None, prob=False):
        if db_result is not None:
            self.db_result = db_result
        features, emb_context, key = self._encode_context(context, db_result)
        action_mask = self._action_mask(self.prev_action)
        probs = self.network(
            [[features]], [[emb_context]], [[key]], [[action_mask]],
            prob=True
        )
        pred_id = np.argmax(probs)

        # one-hot encoding seems to work better than probabilities
        if prob:
            self.prev_action = probs
        else:
            self.prev_action *= 0
            self.prev_action[pred_id] = 1

        return self._decode_response(pred_id)

    def _infer_dialog(self, contexts):
        self.reset()
        res = []
        for context in contexts:
            if context.get('prev_resp_act') is not None:
                action_id = self._encode_response(context.get('prev_resp_act'))
                # previous action is teacher-forced
                self.prev_action *= 0.
                self.prev_action[action_id] = 1.

            res.append(self._infer(context['text'], context.get('db_result')))
        return res

    def make_api_call(self, slots):
        db_results = []
        if self.database is not None:
            # filter out slot keys with the value 'dontcare' as
            # there is no such value in database records
            # and remove unknown slot keys (for example, 'this' in dstc2 tracker)
            db_slots = {s: v for s, v in slots.items()
                        if (v != 'dontcare') and (s in self.database.keys)}
            db_results = self.database([db_slots])[0]
        else:
            log.warn("No database specified.")
        log.info("Made api_call with {}, got {} results.".format(slots, len(db_results)))
        # filter api results if there is more than one
        if len(db_results) > 1:
            db_results = [r for r in db_results if r != self.db_result]
        return db_results[0] if db_results else {}

    def __call__(self, batch):
        if isinstance(batch[0], str):
            res = []
            for x in batch:
                pred = self._infer(x)
                # if an api_call action was made, respond with the next prediction
                prev_act_id = np.argmax(self.prev_action)
                if prev_act_id == self.api_call_id:
                    db_result = self.make_api_call(self.tracker.get_state())
                    res.append(self._infer(x, db_result=db_result))
                else:
                    res.append(pred)
            return res
        return [self._infer_dialog(x) for x in batch]

    def reset(self):
        self.tracker.reset_state()
        self.db_result = None
        self.prev_action = np.zeros(self.n_actions, dtype=np.float32)
        self.network.reset_state()
        if self.debug:
            log.debug("Bot reset.")

    def process_event(self, *args, **kwargs):
        self.network.process_event(*args, **kwargs)
    def save(self):
        """Save the parameters of the model to a file."""
        self.network.save()
    def shutdown(self):
        self.network.shutdown()
        self.slot_filler.shutdown()

    def load(self):
        pass
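# Usage sketch (not part of the original source): assuming `bot` is a fully
# configured GoalOrientedBot (tokenizer, tracker, templates, slot filler and
# database wired together, e.g. via a DeepPavlov pipeline config), inference
# on a batch of raw utterances looks like this:
#
#     bot.reset()                     # clear tracker state, db_result and prev_action
#     responses = bot(['i want cheap food in the north'])
#     print(responses[0])
#
# Inside `__call__`, when the predicted action equals `api_call_action`, the bot
# calls `make_api_call` with the current tracker state and immediately re-infers
# on the same utterance with the obtained `db_result`, so one user turn can take
# two network steps. Batches of whole dialogs (lists of context dicts) are routed
# to `_infer_dialog` instead.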