# Source code for deeppavlov.dataset_readers.kvret_reader

# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from pathlib import Path
from typing import Dict, List

from overrides import overrides

from deeppavlov.core.common.log import get_logger
from deeppavlov.core.common.registry import register
from deeppavlov.core.data.dataset_reader import DatasetReader
from deeppavlov.core.data.utils import download_decompress, mark_done

# Module-level logger obtained from the project's get_logger helper, keyed by this module's name.
log = get_logger(__name__)

@register('kvret_reader')
class KvretDatasetReader(DatasetReader):
    """
    A New Multi-Turn, Multi-Domain, Task-Oriented Dialogue Dataset.

    Stanford NLP released a corpus of 3,031 multi-turn dialogues in three distinct
    domains appropriate for an in-car assistant: calendar scheduling, weather
    information retrieval, and point-of-interest navigation. The dialogues are
    grounded through knowledge bases ensuring that they are versatile in their
    natural language without being completely free form.

    For details see
    https://nlp.stanford.edu/blog/a-new-multi-turn-multi-domain-task-oriented-dialogue-dataset/
    """

    # Archive that read() downloads and decompresses when any of the
    # kvret_{train,dev,test}_public.json files is missing from data_path.
    # NOTE(review): confirm this mirror is still reachable.
    url = 'http://files.deeppavlov.ai/datasets/kvret_public.tar.gz'

    @staticmethod
    def _data_fname(datatype):
        """Return the json file name of split ``datatype`` ('train', 'dev' or 'test')."""
        assert datatype in ('train', 'dev', 'test'), "wrong datatype name"
        return 'kvret_{}_public.json'.format(datatype)

    @classmethod
    @overrides
    def read(cls, data_path: str, dialogs: bool = False) -> Dict[str, List]:
        """
        Downloads ``'kvret_public.tar.gz'`` (only if some split file is missing),
        decompresses it and saves files to ``data_path``, then reads all splits.

        Parameters:
            data_path: path to save data
            dialogs: flag indicating whether to output a flat list of turns (``False``)
                or a list of dialogs, each being a list of turns (``True``)

        Returns:
            dictionary with ``'train'`` containing dialogs from ``'kvret_train_public.json'``,
            ``'valid'`` containing dialogs from ``'kvret_dev_public.json'``,
            ``'test'`` containing dialogs from ``'kvret_test_public.json'``.
            Each turn is a tuple ``(x_i, y_i)`` of dicts.
        """
        required_files = [cls._data_fname(dt) for dt in ('train', 'dev', 'test')]
        if not all(Path(data_path, f).exists() for f in required_files):
            log.info('[downloading kvret from {} to {}]'.format(cls.url, data_path))
            download_decompress(cls.url, data_path)
            mark_done(data_path)

        # The KVRET "dev" split is exposed under the conventional 'valid' key.
        splits = {'train': 'train', 'valid': 'dev', 'test': 'test'}
        return {key: cls._read_from_file(Path(data_path, cls._data_fname(split)), dialogs)
                for key, split in splits.items()}

    @classmethod
    def _read_from_file(cls, file_path, dialogs=False):
        """Return ``(x, y)`` turns from one KVRET json file.

        If ``dialogs`` is true, turns are grouped into one list per dialog.
        """
        log.info("[loading dialogs from {}]".format(file_path))

        utterances, responses, dialog_indices = \
            cls._get_turns(cls._iter_file(file_path), with_indices=True)

        data = list(map(cls._format_turn, zip(utterances, responses)))

        if dialogs:
            return [data[idx['start']:idx['end']] for idx in dialog_indices]
        return data

    @staticmethod
    def _format_turn(turn):
        """Convert an ``(utterance, response)`` pair of raw replicas into ``(x, y)`` dicts.

        ``x`` holds the driver text, dialog id, knowledge-base columns/items and the
        ``requested``/``slots`` annotations (``{}`` if absent); ``episode_done`` is
        copied only when present. ``y`` holds the assistant text, the scenario task
        (taken from the utterance replica) and the response's annotations.
        """
        utter, resp = turn
        x = {'text': utter['utterance'],
             'dialog_id': utter['dialog_id'],
             'kb_columns': utter['kb_columns'],
             'kb_items': utter['kb_items'],
             'requested': utter.get('requested', {}),
             'slots': utter.get('slots', {})}
        if utter.get('episode_done') is not None:
            x['episode_done'] = utter['episode_done']
        y = {'text': resp['utterance'],
             'task': utter['task'],
             'requested': resp.get('requested', {}),
             'slots': resp.get('slots', {})}
        return (x, y)

    @staticmethod
    def _check_dialog(dialog):
        """Return True if turns strictly alternate driver/assistant, starting with the driver.

        A dialog may end on either speaker; a trailing driver turn is padded with
        an empty response later in ``_get_turns``.

        Raises:
            RuntimeError: if a turn's ``'turn'`` value is neither ``'driver'`` nor ``'assistant'``.
        """
        # TODO: manually fix bad dialogs instead of dropping them
        expected = 'driver'
        for turn in dialog:
            speaker = turn['turn']
            if speaker not in ('driver', 'assistant'):
                raise RuntimeError("Dataset wrong format: `turn` key value should be"
                                   " either `driver` or `assistant`.")
            if speaker != expected:
                log.debug("Turn is expected to be {}'s, but it's {}'s".format(expected, speaker))
                return False
            expected = 'assistant' if expected == 'driver' else 'driver'
        return True

    @staticmethod
    def _filter_duplicates(dialog):
        """Yield turns of ``dialog``, dropping consecutive repeats (same speaker and same utterance)."""
        last = (None, None)
        for turn in dialog:
            key = (turn['turn'], turn['data']['utterance'])
            if key != last:
                yield turn
            last = key

    @classmethod
    def _iter_file(cls, file_path):
        """Yield ``(dialog, scenario)`` for each well-formed sample of a KVRET json file.

        Consecutive duplicate turns are removed first; dialogs rejected by
        ``_check_dialog`` are skipped with a warning.
        """
        with open(file_path, 'rt', encoding='utf8') as f:
            data = json.load(f)
        for i, sample in enumerate(data):
            dialog = list(cls._filter_duplicates(sample['dialogue']))
            if cls._check_dialog(dialog):
                yield dialog, sample['scenario']
            else:
                log.warning("Skipping {}th dialogue with uuid={}: wrong format."
                            .format(i, sample['scenario']['uuid']))

    @staticmethod
    def _get_turns(data, with_indices=False):
        """Flatten ``(dialog, scenario)`` pairs into aligned utterance/response lists.

        Driver replicas get the scenario task, dialog uuid and knowledge-base
        columns/items and go to ``utterances``; other replicas go to ``responses``.
        The first replica of each dialog gets ``episode_done=True``. A dialog that
        ends on a driver turn is padded with an empty response, and the last
        response of each dialog has ``END_OF_DIALOGUE`` appended to its text.
        Empty dialogs are skipped. Replica dicts from ``data`` are modified in place.

        Alignment of the two lists assumes every dialog alternates driver/assistant
        starting with the driver, as checked by ``_check_dialog``.

        Returns:
            ``(utterances, responses)``, plus ``dialog_indices`` when ``with_indices``
            is true: one ``{'start': s, 'end': e}`` per dialog such that
            ``utterances[s:e]`` and ``responses[s:e]`` are exactly that dialog's turns.
        """
        utterances, responses, dialog_indices = [], [], []
        for dialog, scenario in data:
            if not dialog:
                # Nothing to emit; also prevents tagging the previous dialog's
                # last response (or indexing an empty list) below.
                continue
            start = len(utterances)
            for i, turn in enumerate(dialog):
                replica = turn['data']
                if i == 0:
                    replica['episode_done'] = True
                if turn['turn'] == 'driver':
                    replica['task'] = scenario['task']
                    replica['dialog_id'] = scenario['uuid']
                    replica['kb_columns'] = scenario['kb']['column_names']
                    replica['kb_items'] = scenario['kb']['items']
                    utterances.append(replica)
                else:
                    responses.append(replica)

            # if last replica was by driver, pad with an empty assistant response
            if len(responses) != len(utterances):
                utterances[-1]['end_dialogue'] = False
                responses.append({'utterance': '', 'end_dialogue': True})

            last_utter = responses[-1]['utterance']
            if last_utter and not last_utter[-1].isspace():
                last_utter += ' '
            responses[-1]['utterance'] = last_utter + 'END_OF_DIALOGUE'

            # Indices are taken around this dialog's own appends, so slicing the
            # zipped (utterance, response) list with them yields exactly this dialog.
            dialog_indices.append({'start': start, 'end': len(utterances)})

        if with_indices:
            return utterances, responses, dialog_indices
        return utterances, responses