
# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import itertools
import re
from logging import getLogger
from typing import Tuple, List, Optional, Union, Dict, Any

import nltk

from deeppavlov.core.common.registry import register
from deeppavlov.models.kbqa.wiki_parser_online import WikiParserOnline
from deeppavlov.models.kbqa.rel_ranking_infer import RelRankerInfer
from deeppavlov.models.kbqa.rel_ranking_bert_infer import RelRankerBertInfer
from deeppavlov.models.kbqa.utils import \
    extract_year, extract_number, make_combs, fill_online_query
from deeppavlov.models.kbqa.query_generator_base import QueryGeneratorBase

log = getLogger(__name__)


@register('query_generator_online')
class QueryGeneratorOnline(QueryGeneratorBase):
    """Class for online query generation using the Wikidata query service"""
    def __init__(self, wiki_parser: WikiParserOnline,
                 rel_ranker: Union[RelRankerInfer, RelRankerBertInfer],
                 entities_to_leave: int = 5,
                 rels_to_leave: int = 7,
                 return_answers: bool = False, *args, **kwargs) -> None:
        """

        Args:
            wiki_parser: component deeppavlov.models.kbqa.wiki_parser_online
            rel_ranker: component deeppavlov.models.kbqa.rel_ranking_infer or
                deeppavlov.models.kbqa.rel_ranking_bert_infer
            entities_to_leave: how many entities to leave after entity linking
            rels_to_leave: how many relations to leave after relation ranking
            return_answers: whether to return answers or candidate answers
        """
        self.wiki_parser = wiki_parser
        self.rel_ranker = rel_ranker
        self.entities_to_leave = entities_to_leave
        self.rels_to_leave = rels_to_leave
        self.return_answers = return_answers
        super().__init__(wiki_parser=self.wiki_parser, rel_ranker=self.rel_ranker,
                         entities_to_leave=self.entities_to_leave,
                         rels_to_leave=self.rels_to_leave,
                         return_answers=self.return_answers, *args, **kwargs)

        self.load()
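    # A minimal construction sketch (illustration only: the component is
    # normally built from a DeepPavlov config, and the base-class arguments
    # forwarded through **kwargs, such as the SPARQL templates file, are
    # elided here rather than filled in):
    #
    #     wiki_parser = WikiParserOnline(...)
    #     rel_ranker = RelRankerInfer(...)
    #     generator = QueryGeneratorOnline(wiki_parser=wiki_parser,
    #                                      rel_ranker=rel_ranker,
    #                                      entities_to_leave=5,
    #                                      rels_to_leave=7,
    #                                      return_answers=True)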
    def __call__(self, question_batch: List[str],
                 question_san_batch: List[str],
                 template_type_batch: List[str],
                 entities_from_ner_batch: List[List[str]],
                 types_from_ner_batch: List[List[str]]) -> List[Union[List[Tuple[str, Any]], List[str]]]:
        candidate_outputs_batch = []
        for question, question_sanitized, template_type, entities_from_ner, types_from_ner in \
                zip(question_batch, question_san_batch, template_type_batch,
                    entities_from_ner_batch, types_from_ner_batch):
            candidate_outputs, _ = self.find_candidate_answers(question, question_sanitized, template_type,
                                                               entities_from_ner, types_from_ner)
            candidate_outputs_batch.append(candidate_outputs)
        if self.return_answers:
            answers = self.rel_ranker(question_batch, candidate_outputs_batch)
            log.debug(f"(__call__)answers: {answers}")
            return answers
        else:
            log.debug(f"(__call__)candidate_outputs_batch: {[output[:5] for output in candidate_outputs_batch]}")
            return candidate_outputs_batch
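    # Expected call shape, sketched with a hand-written example (an
    # illustration, not the output of a real run; the template label
    # "simple_question" is hypothetical):
    #
    #     generator(["Who directed Forrest Gump?"],
    #               ["who directed forrest gump"],
    #               ["simple_question"],
    #               [["Forrest Gump"]],   # entity substrings found by NER
    #               [[]])                 # type substrings found by NER (none here)
    #
    # With return_answers=False each element of the result is the list of
    # candidate answers for the corresponding question; with
    # return_answers=True the rel_ranker selects a final answer instead.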
    def query_parser(self, question: str, query_info: Dict[str, str],
                     entities_and_types_select: str,
                     entity_ids: List[List[str]],
                     type_ids: List[List[str]],
                     rels_from_template: Optional[List[Tuple[str]]] = None) -> List[Tuple[str]]:
        question_tokens = nltk.word_tokenize(question)
        query = query_info["query_template"].lower().replace("wdt:p31", "wdt:P31")
        rels_for_search = query_info["rank_rels"]
        rel_types = query_info["rel_types"]
        rels_for_filter = query_info["filter_rels"]
        property_types = query_info["property_types"]
        query_seq_num = query_info["query_sequence"]
        return_if_found = query_info["return_if_found"]
        log.debug(f"(query_parser)query: {query}, {rels_for_search}, {query_seq_num}, {return_if_found}")

        # Split the template body into triplets and determine, for each ranked
        # triplet, whether the relation points from or to the unknown variable.
        query_triplets = re.findall("{[ ]?(.*?)[ ]?}", query)[0].split(' . ')
        log.debug(f"(query_parser)query_triplets: {query_triplets}")
        query_triplets = [triplet.split(' ')[:3] for triplet in query_triplets]
        triplet_info_list = [("forw" if triplet[2].startswith('?') else "backw", search_source, rel_type)
                             for search_source, triplet, rel_type in zip(rels_for_search, query_triplets, rel_types)
                             if search_source != "do_not_rank"]
        log.debug(f"(query_parser)rel_directions: {triplet_info_list}")
        rel_variables = re.findall(r":(r[\d]{1,2})", query)
        entity_ids = [entity[:self.entities_to_leave] for entity in entity_ids]

        # Take relations from the template if given, otherwise rank candidate
        # relations for each triplet.
        if rels_from_template is not None:
            rels = [[(rel, 1.0) for rel in rel_list] for rel_list in rels_from_template]
        else:
            rels = [self.find_top_rels(question, entity_ids, triplet_info)
                    for triplet_info in triplet_info_list]

        # Split relations into those substituted into the template and those
        # used only for filtering the query results.
        rels_list_for_filter = []
        rels_list_for_fill = []
        filter_rel_variables = []
        fill_rel_variables = []
        for rel_variable, rel_list, is_filter in zip(rel_variables, rels, rels_for_filter):
            if is_filter:
                rels_list_for_filter.append(rel_list)
                filter_rel_variables.append(rel_variable)
            else:
                rels_list_for_fill.append(rel_list)
                fill_rel_variables.append(rel_variable)
        log.debug(f"(query_parser)rels: {rels}")
        log.debug(f"rel_variables {rel_variables}, filter_rel_variables: {filter_rel_variables}")
        log.debug(f"rels_list_for_filter: {rels_list_for_filter}")
        log.debug(f"rels_list_for_fill: {rels_list_for_fill}")
        rels_from_query = list(set([triplet[1] for triplet in query_triplets if triplet[1].startswith('?')]))
        if "count" in query:
            answer_ent = re.findall(r"as (\?[\S]+)", query)
        else:
            answer_ent = re.findall(r"select [\(]?([\S]+) ", query)
        filter_from_query = re.findall(r"contains\((\?\w), (.+?)\)", query)
        log.debug(f"(query_parser)filter_from_query: {filter_from_query}")

        # Substitute a year or number extracted from the question into the
        # filter(contains(...)) clauses of the template.
        year = extract_year(question_tokens, question)
        number = extract_number(question_tokens, question)
        log.debug(f"year {year}, number {number}")
        if year:
            for elem in filter_from_query:
                query = query.replace(f"{elem[0]}, n", f"YEAR({elem[0]}), {year}")
        elif number:
            for elem in filter_from_query:
                query = query.replace(f"{elem[0]}, n", f"{elem[0]}, {number}")
        query = query.replace(" where", f" {' '.join(rels_from_query)} where")
        log.debug(f"(query_parser)query_with_filtering: {query}")

        # Enumerate all combinations of candidate entities, types and
        # relations and fill the template once per combination.
        rel_combs = make_combs(rels_list_for_fill, permut=False)
        log.debug(f"(query_parser)rel_combs: {rel_combs[:3]}")
        start_time = datetime.datetime.now()
        entity_positions, type_positions = [elem.split('_') for elem in entities_and_types_select.split(' ')]
        log.debug(f"entity_positions {entity_positions}, type_positions {type_positions}")
        selected_entity_ids = [entity_ids[int(pos) - 1] for pos in entity_positions if int(pos) > 0]
        selected_type_ids = [type_ids[int(pos) - 1] for pos in type_positions if int(pos) > 0]
        entity_combs = make_combs(selected_entity_ids, permut=True)
        log.debug(f"(query_parser)entity_combs: {entity_combs[:3]}")
        type_combs = make_combs(selected_type_ids, permut=False)
        log.debug(f"(query_parser)type_combs: {type_combs[:3]}")
        confidence = 0.0
        queries_list = []
        parser_info_list = []
        all_combs_list = list(itertools.product(entity_combs, type_combs, rel_combs))
        for comb_num, combs in enumerate(all_combs_list):
            filled_query, filter_rels = fill_online_query(query, combs[0], combs[1], combs[2],
                                                          fill_rel_variables, filter_rel_variables,
                                                          rels_list_for_filter)
            if comb_num == 0:
                log.debug(f"\n___________________________\nfilled query: {filled_query}\n___________________________\n")
            queries_list.append((filled_query, return_if_found))
            parser_info_list.append("query_execute")

        # Execute the filled queries against the Wikidata query service and
        # keep only outputs that satisfy the property-type constraints.
        candidate_outputs_list = self.wiki_parser(parser_info_list, queries_list)
        outputs_len = len(candidate_outputs_list)
        all_combs_list = all_combs_list[:outputs_len]
        # out variables: filter relations (from the last filled query), query
        # variables and the answer variable
        out_vars = filter_rels + rels_from_query + answer_ent
        candidate_outputs = []
        for combs, candidate_output in zip(all_combs_list, candidate_outputs_list):
            candidate_output = [output for output in candidate_output
                                if (all([filter_value in output[filter_var[1:]]["value"]
                                         for filter_var, filter_value in property_types.items()])
                                    and all([not output[ent[1:]]["value"].startswith("http://www.wikidata.org/value")
                                             for ent in answer_ent]))]
            candidate_outputs += [combs[2][:-1] + [output[var[1:]]["value"] for var in out_vars] + [confidence]
                                  for output in candidate_output]
        log.debug(f"(query_parser)loop time: {datetime.datetime.now() - start_time}")
        log.debug(f"(query_parser)final outputs: {candidate_outputs[:3]}")

        return candidate_outputs
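
# The core of query_parser is combinatorial template filling: every
# combination of candidate entities, types and relations is substituted into
# the SPARQL template and sent to the Wikidata query service. The sketch
# below illustrates that idea with a hypothetical one-triplet template and
# hand-picked IDs; it does not use the real make_combs/fill_online_query
# helpers, which additionally handle entity permutations and filter relations.
if __name__ == "__main__":
    # {e} and {r} are illustration-only placeholders, not the placeholder
    # syntax of the real query templates.
    sketch_template = "SELECT ?answer WHERE {{ wd:{e} wdt:{r} ?answer }}"
    entity_candidates = ["Q134773"]    # Forrest Gump (film)
    rel_candidates = ["P57", "P162"]   # director, producer
    for ent, rel in itertools.product(entity_candidates, rel_candidates):
        print(sketch_template.format(e=ent, r=rel))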