Source code for deeppavlov.dataset_iterators.squad_iterator

# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Dict, Any, List, Tuple

from deeppavlov.core.common.registry import register
from deeppavlov.core.data.data_learning_iterator import DataLearningIterator


[docs]@register('squad_iterator') class SquadIterator(DataLearningIterator): """ SquadIterator allows to iterate over examples in SQuAD-like datasets. SquadIterator is used to train :class:`~deeppavlov.models.squad.squad.SquadModel`. It extracts ``context``, ``question``, ``answer_text`` and ``answer_start`` position from dataset. Example from a dataset is a tuple of ``(context, question)`` and ``(answer_text, answer_start)`` Attributes: train: train examples valid: validation examples test: test examples """ def split(self, *args, **kwargs) -> None: for dt in ['train', 'valid', 'test']: setattr(self, dt, SquadIterator._extract_cqas(getattr(self, dt))) @staticmethod def _extract_cqas(data: Dict[str, Any]) -> List[Tuple[Tuple[str, str], Tuple[List[str], List[int]]]]: """ Extracts context, question, answer, answer_start from SQuAD data Args: data: data in squad format Returns: list of (context, question), (answer_text, answer_start) answer text and answer_start are lists """ cqas = [] if data: for article in data['data']: for par in article['paragraphs']: context = par['context'] for qa in par['qas']: q = qa['question'] ans_text = [] ans_start = [] for answer in qa['answers']: ans_text.append(answer['text']) ans_start.append(answer['answer_start']) cqas.append(((context, q), (ans_text, ans_start))) return cqas