# Source code for deeppavlov.dataset_iterators.dstc2_ner_iterator

# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
from typing import Any, Dict, List, Optional, Tuple

from deeppavlov.core.commands.utils import expand_path
from deeppavlov.core.common.registry import register
from deeppavlov.core.data.data_learning_iterator import DataLearningIterator

logger = logging.getLogger(__name__)

@register('dstc2_ner_iterator')
class Dstc2NerDatasetIterator(DataLearningIterator):
    """
    Iterates over data for DSTC2 NER task. Dataset takes a dict with fields 'train', 'test', 'valid'.
    A list of samples (pairs x, y) is stored in each field. Every sample is converted into a
    ``(tokens, tags)`` pair whose tags are BIO labels derived from the slots annotated on the utterance.

    Args:
        data: list of (x, y) pairs, samples from the dataset: x as well as y can be a tuple of different input
            features.
        slot_values_path: path to a JSON file mapping ``slot_type -> {slot_value: [surface forms]}``; the surface
            forms are what gets matched against utterance tokens.
        seed: value for random seed
        shuffle: whether to shuffle the data
    """

    def __init__(self, data: Dict[str, List[Tuple]], slot_values_path: str, seed: Optional[int] = None,
                 shuffle: bool = False):
        # TODO: include slot vals to dstc2.tar.gz
        # Slot values are loaded before super().__init__, which presumably runs preprocess() on the
        # data splits and therefore needs self._slot_vals already set.
        with expand_path(slot_values_path).open(encoding='utf8') as f:
            self._slot_vals = json.load(f)
        super().__init__(data, seed, shuffle)

    def preprocess(self, data: List[Tuple[Any, Any]], *args, **kwargs) -> List[Tuple[Any, Any]]:
        """Convert raw DSTC2 samples into BIO-tagged ``(tokens, tags)`` pairs.

        Samples whose text is empty or whitespace-only are skipped, as are repeated ``(text, slots)`` pairs.

        Args:
            data: list of ``(x, y)`` samples; only ``x`` is used. ``x['text']`` holds the utterance. Slots are
                read from each ``x['intents'][i]['slots']`` or, when ``x`` has no ``'intents'`` key, from
                ``x['slots']``. Each slot is a ``(slot_type, slot_value)`` pair.

        Returns:
            list of ``(tokens, tags)`` pairs, one per kept sample.
        """
        processed_data = []
        # text -> slot lists already emitted for that text; used to drop duplicate (text, slots) pairs
        seen_slots_by_text: Dict[str, List[List[Tuple[Any, Any]]]] = {}
        for x, _ in data:
            text = x['text']
            if not text.strip():
                continue

            if 'intents' in x:
                intents = x['intents']
            elif 'slots' in x:
                intents = [x]
            else:
                intents = []

            # aggregate slots from different intents; when slot values were loaded, keep only known slot types,
            # otherwise keep every slot type
            slots = [(slot_type, slot_val)
                     for intent in intents
                     for slot_type, slot_val in intent.get('slots', [])
                     if not self._slot_vals or slot_type in self._slot_vals]

            # remove duplicate pairs (text, slots)
            seen = seen_slots_by_text.setdefault(text, [])
            if slots in seen:
                continue
            seen.append(slots)

            processed_data.append(self._add_bio_markup(text, slots))
        return processed_data

    def _add_bio_markup(self, utterance: str, slots: List[Tuple[str, str]]) -> Tuple[List, List]:
        """Tag the whitespace tokens of ``utterance`` with BIO labels for the given slots.

        At every token position, each slot's surface forms are tried in turn. The surface forms come from the
        loaded slot values, falling back to the raw slot value. The first form whose tokens match at that
        position sets ``B-<slot_type>`` on the first token and ``I-<slot_type>`` on the rest. Matches found
        later (at a later position or for a later slot) may overwrite earlier tags.

        Args:
            utterance: raw utterance text; tokenized with ``str.split()``.
            slots: ``(slot_type, slot_value)`` pairs annotated on the utterance.

        Returns:
            ``(tokens, tags)``; both lists have the same length and untagged tokens get ``'O'``.
        """
        tokens = utterance.split()
        n_toks = len(tokens)
        tags = ['O'] * n_toks
        for n in range(n_toks):
            for slot_type, slot_val in slots:
                # .get(slot_type, {}) covers slot types absent from the loaded values, which preprocess() lets
                # through when no slot values were loaded at all (previously a KeyError)
                entities = self._slot_vals.get(slot_type, {}).get(slot_val, [slot_val])
                for entity in entities:
                    slot_tokens = entity.split()
                    slot_len = len(slot_tokens)
                    # an empty surface form would otherwise "match" at every position and emit a spurious B- tag;
                    # forms longer than the remaining tokens cannot match here
                    if slot_len == 0 or n + slot_len > n_toks:
                        continue
                    if self._is_equal_sequences(tokens[n:n + slot_len], slot_tokens):
                        tags[n] = 'B-' + slot_type
                        for k in range(1, slot_len):
                            tags[n + k] = 'I-' + slot_type
                        break
        return tokens, tags

    @staticmethod
    def _is_equal_sequences(seq1, seq2) -> bool:
        """Return True when ``seq1`` and ``seq2`` are element-wise equal.

        Pairs are formed with ``zip``, so only the overlapping prefix is compared; callers pass sequences of
        equal length.
        """
        return all(tok1 == tok2 for tok1, tok2 in zip(seq1, seq2))