Source code for deeppavlov.models.kbqa.type_define

# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pickle
from typing import List

import spacy
from nltk.corpus import stopwords

from deeppavlov.core.commands.utils import expand_path
from deeppavlov.core.common.registry import register


@register('answer_types_extractor')
class AnswerTypesExtractor:
    """Class which defines answer types for the question"""

    def __init__(self, lang: str, types_filename: str, types_sets_filename: str,
                 num_types_to_return: int = 15, **kwargs):
        """

        Args:
            lang: language tag, "@en" for English or "@ru" for Russian
            types_filename: filename with dictionary where keys are type ids and values are type labels
            types_sets_filename: filename with dictionary where keys are NER tags and values are Wikidata types
                corresponding to tags
            num_types_to_return: how many answer types to return for each question
            **kwargs:

        Raises:
            ValueError: if ``lang`` is neither "@en" nor "@ru"
        """
        self.lang = lang
        self.types_filename = str(expand_path(types_filename))
        self.types_sets_filename = str(expand_path(types_sets_filename))
        self.num_types_to_return = num_types_to_return
        if self.lang == "@en":
            self.stopwords = set(stopwords.words("english"))
            self.nlp = spacy.load("en_core_web_sm")
            self.pronouns = ["what"]
        elif self.lang == "@ru":
            self.stopwords = set(stopwords.words("russian"))
            self.nlp = spacy.load("ru_core_news_sm")
            self.pronouns = ["какой", "каком"]
        else:
            # Fail fast: without this branch nlp/stopwords/pronouns stay undefined and the
            # first __call__ dies with an opaque AttributeError.
            raise ValueError(f"Unsupported lang {lang!r}: expected '@en' or '@ru'")
        # NOTE: pickle.load executes arbitrary code from the file; both files must come
        # from a trusted source.
        with open(self.types_filename, 'rb') as fl:
            self.types_dict = pickle.load(fl)
        with open(self.types_sets_filename, 'rb') as fl:
            self.types_sets = pickle.load(fl)

    def __call__(self, questions_batch: List[str], entity_substr_batch: List[List[str]],
                 tags_batch: List[List[str]], types_substr_batch: List[List[str]] = None):
        """Define answer type sets for the questions and fill in missing entity substrings.

        Args:
            questions_batch: questions to process
            entity_substr_batch: entity substrings found in each question
            tags_batch: tags of the entity substrings of each question
            types_substr_batch: type substrings of each question; extracted from the
                dependency parse of the questions when omitted

        Returns:
            tuple of three lists with one element per question: the set of answer types
            (empty if the question matches no rule), the entity substrings and their tags.
            For a question with no entity substrings, the first "nsubj" token of the
            question (if any) is returned as the entity with the tag "MISC".
        """
        if types_substr_batch is None:
            # NOTE(review): the extracted substrings are not used below and are not part of
            # the returned value; kept so that the questions are parsed exactly as before.
            types_substr_batch = [
                self._extract_types_substr(question, entity_substr_list)
                for question, entity_substr_list in zip(questions_batch, entity_substr_batch)
            ]

        types_sets_batch = [self._types_set_for_question(question.lower())
                            for question in questions_batch]

        new_entity_substr_batch, new_tags_batch = [], []
        for question, entity_substr_list, tags_list in zip(questions_batch, entity_substr_batch, tags_batch):
            if entity_substr_list:
                new_entity_substr_batch.append(entity_substr_list)
                new_tags_batch.append(tags_list)
            else:
                new_entity_substr, new_tags = self._fallback_entity(question)
                new_entity_substr_batch.append(new_entity_substr)
                new_tags_batch.append(new_tags)

        return types_sets_batch, new_entity_substr_batch, new_tags_batch

    def _extract_types_substr(self, question: str, entity_substr_list: List[str]) -> str:
        """Build the type substring of a question from its dependency parse."""
        doc = self.nlp(question)
        # Keyed by token text, so a repeated word maps to its last position in the question.
        token_pos_dict = {token.text: n for n, token in enumerate(doc)}

        def in_entities(substr: str) -> bool:
            return any(substr in entity_substr.lower() for entity_substr in entity_substr_list)

        types_substr = []
        type_noun = ""
        # The type noun is the head of the first question pronoun whose head is attr/nsubj.
        for token in doc:
            if token.text.lower() in self.pronouns and token.head.dep_ in ["attr", "nsubj"]:
                type_noun = token.head.text
                if not in_entities(type_noun):
                    types_substr.append(type_noun)
                break

        if type_noun:
            # Add the modifiers of the type noun; the scan stops at the first amod/compound.
            for token in doc:
                if token.head.text != type_noun:
                    continue
                if token.dep_ in ["amod", "compound"]:
                    if not in_entities(token.text.lower()):
                        types_substr.append(token.text)
                    break
                if token.dep_ == "prep":
                    children = list(token.children)
                    if len(children) == 1 and not in_entities(children[0].text):
                        types_substr += [token.text, children[0].text]
        elif any(word in question for word in self.pronouns):
            for token in doc:
                if token.dep_ == "nsubj" and not in_entities(token.text):
                    types_substr.append(token.text)

        # Restore the order in which the words appear in the question (stable sort).
        types_substr.sort(key=lambda text: token_pos_dict[text])
        return " ".join(types_substr)

    def _types_set_for_question(self, question: str) -> set:
        """Map the leading question word of a lowercased question to a set of answer types."""
        if self.lang == "@ru":
            if question.startswith("кто"):
                return self.types_sets["PER"]
            if question.startswith("где"):
                return self.types_sets["LOC"]
            if question.startswith(("когда", "в каком году", "в каком месяце")):
                return {"date"}
            words = question.split()
            if len(words) > 1 and (question.startswith(("кем ", "как")) or words[1].startswith("как")):
                return {"not_date"}
        elif self.lang == "@en":
            if question.startswith("who"):
                return self.types_sets["PER"]
            if question.startswith("where"):
                return self.types_sets["LOC"]
            if question.startswith(("when", "what year", "what month")):
                return {"date"}
        return set()

    def _fallback_entity(self, question: str):
        """Return the first "nsubj" token of the question as a single "MISC" entity."""
        for token in self.nlp(question):
            if token.dep_ == "nsubj":
                return [token.text], ["MISC"]
        return [], []