# Source code for deeppavlov.models.kbqa.query_generator

# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import re
from collections import namedtuple, OrderedDict
from logging import getLogger
from typing import Tuple, List, Optional, Union, Dict, Any, Set

import nltk
import numpy as np

from deeppavlov.core.common.registry import register
from deeppavlov.models.kbqa.query_generator_base import QueryGeneratorBase
from deeppavlov.models.kbqa.rel_ranking_infer import RelRankerInfer
from deeppavlov.models.kbqa.utils import \
    extract_year, extract_number, order_of_answers_sorting, make_combs, fill_query
from deeppavlov.models.kbqa.wiki_parser import WikiParser

log = getLogger(__name__)


@register('query_generator')
class QueryGenerator(QueryGeneratorBase):
    """Class for query generation using Wikidata hdt file"""

    def __init__(self, wiki_parser: WikiParser,
                 rel_ranker: RelRankerInfer,
                 entities_to_leave: int = 5,
                 rels_to_leave: int = 7,
                 max_comb_num: int = 10000,
                 return_all_possible_answers: bool = False, *args, **kwargs) -> None:
        """
        Args:
            wiki_parser: component deeppavlov.models.kbqa.wiki_parser
            rel_ranker: component deeppavlov.models.kbqa.rel_ranking_infer
            entities_to_leave: how many entities to leave after entity linking
            rels_to_leave: how many relations to leave after relation ranking
            max_comb_num: the maximum number of combinations of candidate entities and relations
            return_all_possible_answers: whether to return all found answers
            **kwargs:
        """
        self.wiki_parser = wiki_parser
        self.rel_ranker = rel_ranker
        self.entities_to_leave = entities_to_leave
        self.rels_to_leave = rels_to_leave
        self.max_comb_num = max_comb_num
        self.return_all_possible_answers = return_all_possible_answers
        # SPARQL templates are lower-cased before filling; Wikidata property ids
        # are case-sensitive, so restore the canonical capitalization here.
        self.replace_tokens = [("wdt:p31", "wdt:P31"), ("pq:p580", "pq:P580"),
                               ("pq:p582", "pq:P582"), ("pq:p585", "pq:P585"),
                               ("pq:p1545", "pq:P1545")]
        super().__init__(wiki_parser=self.wiki_parser, rel_ranker=self.rel_ranker,
                         entities_to_leave=self.entities_to_leave,
                         rels_to_leave=self.rels_to_leave, *args, **kwargs)

    def __call__(self, question_batch: List[str],
                 question_san_batch: List[str],
                 template_type_batch: Union[List[List[str]], List[str]],
                 entities_from_ner_batch: List[List[str]],
                 entity_tags_batch: List[List[str]],
                 answer_types_batch: List[Set[str]]) -> List[str]:
        """Answer a batch of questions.

        Args:
            question_batch: raw questions
            question_san_batch: sanitized questions
            template_type_batch: per-question query template type ("-1" means unknown)
            entities_from_ner_batch: entity substrings found by NER
            entity_tags_batch: tags of the found entities
            answer_types_batch: allowed answer types per question

        Returns:
            one answer string per question ("Not Found" when nothing was ranked)
        """
        candidate_outputs_batch = []
        template_answers_batch = []
        templates_nums_batch = []
        log.debug(f"kbqa inputs {question_batch} {entities_from_ner_batch} {template_type_batch} {entity_tags_batch}")
        for question, question_sanitized, template_type, entities_from_ner, entity_tags_list, answer_types in \
                zip(question_batch, question_san_batch, template_type_batch, entities_from_ner_batch,
                    entity_tags_batch, answer_types_batch):
            # "-1" marks an undetected template type; fall back to template "7"
            if template_type == "-1":
                template_type = "7"
            candidate_outputs, template_answer, templates_nums = \
                self.find_candidate_answers(question, question_sanitized, template_type,
                                            entities_from_ner, entity_tags_list, answer_types)
            candidate_outputs_batch.append(candidate_outputs)
            template_answers_batch.append(template_answer)
            templates_nums_batch.append(templates_nums)
        answers = self.rel_ranker(question_batch, candidate_outputs_batch,
                                  entities_from_ner_batch, template_answers_batch)
        log.debug(f"(__call__)answers: {answers}")
        if not answers:
            answers = ["Not Found" for _ in question_batch]
        return answers

    def query_parser(self, question: str, query_info: Dict[str, str],
                     entities_and_types_select: List[str],
                     entity_ids: List[List[str]],
                     type_ids: List[List[str]],
                     answer_types: Set[str],
                     rels_from_template: Optional[List[Tuple[str]]] = None) -> Union[
            List[Dict[str, Union[Union[Tuple[Any, ...], List[Any]], Any]]], List[Dict[str, Any]]]:
        """Fill a SPARQL query template with candidate entities/types/relations and execute it.

        Builds the cartesian product of candidate entity, type and relation
        combinations (capped at ``max_comb_num``), fills the query template for
        each combination, executes them through ``wiki_parser`` and returns the
        candidate answers with their relation confidences.

        Args:
            question: the question text
            query_info: template description (query string, ranking flags, sequence, ...)
            entities_and_types_select: which entity/type slots of the template to fill,
                e.g. "1_2 1" (entity positions, then type positions)
            entity_ids: candidate entity ids per detected entity
            type_ids: candidate type ids per detected type
            answer_types: allowed answer types (only passed with the first combination)
            rels_from_template: relations fixed by the template; when given,
                relation ranking is skipped and each gets confidence 1.0

        Returns:
            list of dicts with "entities", "relations", "answers", "rel_conf"
        """
        question_tokens = nltk.word_tokenize(question)
        query = query_info["query_template"].lower()
        # lower() above breaks case-sensitive Wikidata property ids -- restore them
        for old_tok, new_tok in self.replace_tokens:
            query = query.replace(old_tok, new_tok)
        log.debug(f"\n_______________________________\nquery: {query}\n_______________________________\n")
        rels_for_search = query_info["rank_rels"]
        rel_types = query_info["rel_types"]
        query_seq_num = query_info["query_sequence"]
        return_if_found = query_info["return_if_found"]
        define_sorting_order = query_info["define_sorting_order"]
        property_types = query_info["property_types"]
        log.debug(f"(query_parser)query: {query}, {rels_for_search}, {query_seq_num}, {return_if_found}")
        # triplets are the " . "-separated patterns inside the WHERE clause braces
        query_triplets = re.findall(r"{[ ]?(.*?)[ ]?}", query)[0].split(' . ')
        log.debug(f"(query_parser)query_triplets: {query_triplets}")
        query_triplets = [triplet.split(' ')[:3] for triplet in query_triplets]
        query_sequence_dict = {num: triplet for num, triplet in zip(query_seq_num, query_triplets)}
        # reorder triplets into the execution order given by query_seq_num (1-based)
        query_sequence = []
        for i in range(1, max(query_seq_num) + 1):
            query_sequence.append(query_sequence_dict[i])
        # direction is "forw" when the object slot is a variable, "backw" otherwise
        triplet_info_list = [("forw" if triplet[2].startswith('?') else "backw", search_source, rel_type)
                             for search_source, triplet, rel_type in
                             zip(rels_for_search, query_triplets, rel_types)
                             if search_source != "do_not_rank"]
        log.debug(f"(query_parser)rel_directions: {triplet_info_list}")
        entity_ids = [entity[:self.entities_to_leave] for entity in entity_ids]
        if rels_from_template is not None:
            rels = [[(rel, 1.0) for rel in rel_list] for rel_list in rels_from_template]
        else:
            rels = [self.find_top_rels(question, entity_ids, triplet_info)
                    for triplet_info in triplet_info_list]
        rels = [[rel for rel in rel_list] for rel_list in rels]
        log.debug(f"(query_parser)rels: {rels}")
        rels_from_query = [triplet[1] for triplet in query_triplets if triplet[1].startswith('?')]
        answer_ent = re.findall(r"select [\(]?([\S]+) ", query)
        order_info_nt = namedtuple("order_info", ["variable", "sorting_order"])
        order_variable = re.findall(r"order by (asc|desc)\((.*)\)", query)
        if order_variable:
            if define_sorting_order:
                answers_sorting_order = order_of_answers_sorting(question)
            else:
                answers_sorting_order = order_variable[0][0]
            order_info = order_info_nt(order_variable[0][1], answers_sorting_order)
        else:
            order_info = order_info_nt(None, None)
        log.debug(f"question, order_info: {question}, {order_info}")
        filter_from_query = re.findall(r"contains\((\?\w), (.+?)\)", query)
        log.debug(f"(query_parser)filter_from_query: {filter_from_query}")
        year = extract_year(question_tokens, question)
        number = extract_number(question_tokens, question)
        log.debug(f"year {year}, number {number}")
        # substitute the placeholder "n" in FILTER expressions with the year/number
        # found in the question; otherwise drop unfilled placeholder filters
        if year:
            filter_info = [(elem[0], elem[1].replace("n", year)) for elem in filter_from_query]
        elif number:
            filter_info = [(elem[0], elem[1].replace("n", number)) for elem in filter_from_query]
        else:
            filter_info = [elem for elem in filter_from_query if elem[1] != "n"]
        for unk_prop, prop_type in property_types.items():
            filter_info.append((unk_prop, prop_type))
        log.debug(f"(query_parser)filter_info: {filter_info}")
        rel_combs = make_combs(rels, permut=False)
        entity_positions, type_positions = [elem.split('_')
                                            for elem in entities_and_types_select.split(' ')]
        log.debug(f"entity_positions {entity_positions}, type_positions {type_positions}")
        selected_entity_ids, selected_type_ids = [], []
        if entity_ids:
            selected_entity_ids = [entity_ids[int(pos) - 1] for pos in entity_positions if int(pos) > 0]
        if type_ids:
            selected_type_ids = [type_ids[int(pos) - 1] for pos in type_positions if int(pos) > 0]
        entity_combs = make_combs(selected_entity_ids, permut=True)
        type_combs = make_combs(selected_type_ids, permut=False)
        log.debug(f"(query_parser)entity_combs: {entity_combs[:3]}, type_combs: {type_combs[:3]},"
                  f" rel_combs: {rel_combs[:3]}")
        queries_list = []
        parser_info_list = []
        confidences_list = []
        all_combs_list = list(itertools.product(entity_combs, type_combs, rel_combs))
        for comb_num, combs in enumerate(all_combs_list):
            # last element of the rel comb is bookkeeping, not a (rel, score) pair
            confidence = np.prod([score for rel, score in combs[2][:-1]])
            confidences_list.append(confidence)
            query_hdt_seq = [
                fill_query(query_hdt_elem, combs[0], combs[1], combs[2])
                for query_hdt_elem in query_sequence]
            if comb_num == 0:
                log.debug(f"\n__________________________\nfilled query: {query_hdt_seq}\n__________________________\n")
            # answer types are only checked for the first (best-ranked) combination
            if comb_num > 0:
                answer_types = []
            queries_list.append(
                (rels_from_query + answer_ent, query_hdt_seq, filter_info, order_info,
                 answer_types, rel_types, return_if_found))
            parser_info_list.append("query_execute")
            if comb_num == self.max_comb_num:
                break
        candidate_outputs = []
        candidate_outputs_list = self.wiki_parser(parser_info_list, queries_list)
        # the API requester wraps results in an extra list -- unwrap it
        if self.use_wp_api_requester and isinstance(candidate_outputs_list, list) and candidate_outputs_list:
            candidate_outputs_list = candidate_outputs_list[0]
        if isinstance(candidate_outputs_list, list) and candidate_outputs_list:
            outputs_len = len(candidate_outputs_list)
            all_combs_list = all_combs_list[:outputs_len]
            confidences_list = confidences_list[:outputs_len]
            for combs, confidence, candidate_output in zip(all_combs_list, confidences_list,
                                                           candidate_outputs_list):
                candidate_outputs += [[combs[0]] + [rel for rel, score in combs[2][:-1]] + output + [confidence]
                                      for output in candidate_output]
            if self.return_all_possible_answers:
                # group answers that share the same (entities, relations) key
                candidate_outputs_dict = OrderedDict()
                for candidate_output in candidate_outputs:
                    candidate_output_key = (tuple(candidate_output[0]), tuple(candidate_output[1:-2]))
                    if candidate_output_key not in candidate_outputs_dict:
                        candidate_outputs_dict[candidate_output_key] = []
                    candidate_outputs_dict[candidate_output_key].append(candidate_output[-2:])
                candidate_outputs = []
                for (candidate_entity_comb, candidate_rel_comb), candidate_output in \
                        candidate_outputs_dict.items():
                    candidate_outputs.append({"entities": candidate_entity_comb,
                                              "relations": list(candidate_rel_comb),
                                              "answers": tuple([ans for ans, conf in candidate_output]),
                                              "rel_conf": candidate_output[0][1]
                                              })
            else:
                candidate_outputs = [{"entities": f_entities,
                                      "relations": f_relations,
                                      "answers": f_answers,
                                      "rel_conf": f_rel_conf
                                      } for f_entities, *f_relations, f_answers, f_rel_conf
                                     in candidate_outputs]
        log.debug(f"(query_parser)final outputs: {candidate_outputs[:3]}")
        return candidate_outputs