Source code for deeppavlov.dataset_readers.dstc2_reader

# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import copy
import json
from logging import getLogger
from pathlib import Path
from typing import Dict, List

from overrides import overrides

from deeppavlov.core.common.registry import register
from deeppavlov.core.data.dataset_reader import DatasetReader
from deeppavlov.core.data.utils import download_decompress, mark_done

log = getLogger(__name__)


@register('dstc2_reader')
class DSTC2DatasetReader(DatasetReader):
    """
    Contains labelled dialogs from Dialog State Tracking Challenge 2
    (http://camdial.org/~mh521/dstc/).

    The following modifications were made to the original dataset:

    1. added api calls to restaurant database

        - example: ``{"text": "api_call area=\"south\" food=\"dontcare\"
          pricerange=\"cheap\"", "dialog_acts": ["api_call"]}``.

    2. new actions

        - bot dialog actions were concatenated into one action
          (example: ``{"dialog_acts": ["ask", "request"]}`` ->
          ``{"dialog_acts": ["ask_request"]}``)

        - if a slot key was associated with the dialog action, the new act
          was a concatenation of an act and a slot key (example:
          ``{"dialog_acts": ["ask"], "slot_vals": ["area"]}`` ->
          ``{"dialog_acts": ["ask_area"]}``)

    3. new train/dev/test split

        - original dstc2 consisted of three different MDP policies, the
          original train and dev datasets (consisting of two policies) were
          merged and randomly split into train/dev/test

    4. minor fixes

        - fixed several dialogs, where actions were wrongly annotated
        - uppercased first letter of bot responses
        - unified punctuation for bot responses
    """

    # Archive that is fetched when any of the three data files is missing.
    url = 'http://files.deeppavlov.ai/datasets/dstc2_v2.tar.gz'

    @staticmethod
    def _data_fname(datatype):
        """Return the data file name for ``datatype`` ('trn', 'val' or 'tst')."""
        assert datatype in ('trn', 'val', 'tst'), "wrong datatype name"
        return 'dstc2-{}.jsonlist'.format(datatype)

    # NOTE: the first parameter is named ``self`` but, because of
    # ``@classmethod``, it is bound to the class, not to an instance.
    @classmethod
    @overrides
    def read(self, data_path: str, dialogs: bool = False) -> Dict[str, List]:
        """
        Reads the DSTC2 dataset from ``data_path``, downloading and
        decompressing the archive at ``url`` first if any of the three
        expected data files is missing.

        Parameters:
            data_path: path to save DSTC2 dataset
            dialogs: flag which indicates whether to output list of turns or
             list of dialogs

        Returns:
            dictionary that contains ``'train'`` field with dialogs from
            ``'dstc2-trn.jsonlist'``, ``'valid'`` field with dialogs from
            ``'dstc2-val.jsonlist'`` and ``'test'`` field with dialogs from
            ``'dstc2-tst.jsonlist'``. Each field is a list of tuples
            ``(x_i, y_i)`` (or a list of such lists, one per dialog, when
            ``dialogs`` is true).
        """
        required_files = (self._data_fname(dt) for dt in ('trn', 'val', 'tst'))
        if not all(Path(data_path, f).exists() for f in required_files):
            log.info('[downloading data from {} to {}]'.format(self.url, data_path))
            download_decompress(self.url, data_path)
            mark_done(data_path)

        data = {
            'train': self._read_from_file(
                Path(data_path, self._data_fname('trn')), dialogs),
            'valid': self._read_from_file(
                Path(data_path, self._data_fname('val')), dialogs),
            'test': self._read_from_file(
                Path(data_path, self._data_fname('tst')), dialogs)
        }
        return data

    @classmethod
    def _read_from_file(cls, file_path, dialogs=False):
        """Returns data from single file.

        Parameters:
            file_path: path to a ``.jsonlist`` file (one JSON object per line,
             blank lines separating dialogs)
            dialogs: if true, group the ``(x, y)`` turn pairs into one list
             per dialog; otherwise return a flat list of pairs

        Returns:
            list of ``(x, y)`` tuples, or list of such lists when ``dialogs``
            is true.
        """
        log.info("[loading dialogs from {}]".format(file_path))
        utterances, responses, dialog_indices = \
            cls._get_turns(cls._iter_file(file_path), with_indices=True)

        data = list(map(cls._format_turn, zip(utterances, responses)))

        if dialogs:
            return [data[idx['start']:idx['end']] for idx in dialog_indices]
        return data

    @staticmethod
    def _format_turn(turn):
        """Convert an ``(utterance, response)`` pair of raw dicts to ``(x, y)``.

        ``x`` carries the user text and dialog acts (as ``'intents'``), plus
        ``'db_result'`` and ``'episode_done'`` when present on the utterance;
        ``y`` carries the response text and the act of its first dialog act.
        """
        x = {'text': turn[0]['text'],
             'intents': turn[0]['dialog_acts']}
        if turn[0].get('db_result') is not None:
            x['db_result'] = turn[0]['db_result']
        if turn[0].get('episode_done'):
            x['episode_done'] = True
        y = {'text': turn[1]['text'],
             'act': turn[1]['dialog_acts'][0]['act']}
        return (x, y)

    @staticmethod
    def _iter_file(file_path):
        """Yield one parsed JSON object per non-blank line of ``file_path``.

        Blank lines are yielded as empty dicts; ``_get_turns`` treats them as
        dialog separators.
        """
        # Use a context manager so the handle is closed once the generator is
        # exhausted or garbage-collected instead of being leaked.
        with open(file_path, 'rt', encoding='utf8') as fin:
            for ln in fin:
                if ln.strip():
                    yield json.loads(ln)
                else:
                    yield {}

    @staticmethod
    def _get_turns(data, with_indices=False):
        """Split a stream of turns into aligned utterance/response lists.

        Parameters:
            data: iterable of turn dicts; an empty dict marks the end of a
             dialog. Every non-empty turn must have a ``'speaker'`` key equal
             to 1 (user) or 2 (bot); the key is removed from the turn.
            with_indices: if true, also return the list of
             ``{'start': ..., 'end': ...}`` index ranges, one per dialog

        Returns:
            ``(utterances, responses)`` or, when ``with_indices`` is true,
            ``(utterances, responses, dialog_indices)``.

        Raises:
            RuntimeError: if a dialog has a different number of utterances and
             responses, if two user turns follow each other, if a bot turn
             following another bot turn has no ``'db_result'`` on the previous
             response, or if the speaker is neither 1 nor 2.
        """
        utterances = []
        responses = []
        dialog_indices = []
        n = 0
        num_dialog_utter, num_dialog_resp = 0, 0
        episode_done = True
        for turn in data:
            if not turn:
                # Dialog separator: close the current dialog.
                if num_dialog_utter != num_dialog_resp:
                    raise RuntimeError("Datafile in the wrong format.")
                episode_done = True
                n += num_dialog_utter
                dialog_indices.append({
                    'start': n - num_dialog_utter,
                    'end': n,
                })
                num_dialog_utter, num_dialog_resp = 0, 0
            else:
                speaker = turn.pop('speaker')
                if speaker == 1:
                    # First user turn of a dialog is flagged as episode start.
                    if episode_done:
                        turn['episode_done'] = True
                    utterances.append(turn)
                    num_dialog_utter += 1
                elif speaker == 2:
                    if num_dialog_utter - 1 == num_dialog_resp:
                        # Regular case: bot answers the pending user turn.
                        responses.append(turn)
                    elif num_dialog_utter - 1 < num_dialog_resp:
                        # Bot turn without a preceding user turn: synthesize
                        # the missing utterance so the lists stay aligned.
                        if episode_done:
                            # Dialog opened by the bot: pair it with an empty
                            # user utterance.
                            responses.append(turn)
                            utterances.append({
                                "text": "",
                                "dialog_acts": [],
                                "episode_done": True}
                            )
                        else:
                            # Two bot turns in a row: repeat the last user
                            # utterance and move ``db_result`` from the
                            # previous response onto it.
                            new_turn = copy.deepcopy(utterances[-1])
                            if 'db_result' not in responses[-1]:
                                raise RuntimeError("Every api_call action should have"
                                                   " db_result, turn = {}"
                                                   .format(responses[-1]))
                            new_turn['db_result'] = responses[-1].pop('db_result')
                            utterances.append(new_turn)
                            responses.append(turn)
                        num_dialog_utter += 1
                    else:
                        raise RuntimeError("there cannot be two successive turns of"
                                           " speaker 1")
                    num_dialog_resp += 1
                else:
                    raise RuntimeError("Only speakers 1 and 2 are supported")
                episode_done = False

        # The input may end without a trailing separator; close the last
        # dialog too so that it is not missing from ``dialog_indices``.
        if num_dialog_utter or num_dialog_resp:
            if num_dialog_utter != num_dialog_resp:
                raise RuntimeError("Datafile in the wrong format.")
            n += num_dialog_utter
            dialog_indices.append({
                'start': n - num_dialog_utter,
                'end': n,
            })

        if with_indices:
            return utterances, responses, dialog_indices
        return utterances, responses