Source code for deeppavlov.models.kbqa.query_generator_base

# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
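# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.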
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import json
from logging import getLogger
from typing import Tuple, List, Dict, Optional, Union, Any, Set

from bs4 import BeautifulSoup
from whapi import search, get_html

from deeppavlov.core.commands.utils import expand_path
from deeppavlov.core.common.file import read_json
from deeppavlov.core.models.component import Component
from deeppavlov.core.models.serializable import Serializable
from deeppavlov.models.entity_extraction.entity_linking import EntityLinker
from deeppavlov.models.kbqa.rel_ranking_infer import RelRankerInfer
from deeppavlov.models.kbqa.template_matcher import TemplateMatcher
from deeppavlov.models.kbqa.utils import preprocess_template_queries

log = getLogger(__name__)

class QueryGeneratorBase(Component, Serializable):
    """Base class for SPARQL query generators.

    This class takes as input entity substrings, defines the template of the query and
    fills the slots of the template with candidate entities and relations.
    Subclasses have to implement :meth:`query_parser`, which is left abstract here.
    """

    def __init__(self, template_matcher: TemplateMatcher,
                 rel_ranker: RelRankerInfer,
                 load_path: str,
                 sparql_queries_filename: str,
                 entity_linker: EntityLinker,
                 rels_in_ranking_queries_fname: Optional[str] = None,
                 wiki_parser=None,
                 entities_to_leave: int = 5,
                 rels_to_leave: int = 7,
                 syntax_structure_known: bool = False,
                 use_wp_api_requester: bool = False,
                 use_el_api_requester: bool = False,
                 use_alt_templates: bool = True,
                 delete_rel_prefix: bool = True,
                 kb_prefixes: Optional[Dict[str, str]] = None, *args, **kwargs) -> None:
        """

        Args:
            template_matcher: component deeppavlov.models.kbqa.template_matcher
            rel_ranker: component deeppavlov.models.kbqa.rel_ranking_infer
            load_path: path to folder with wikidata files
            sparql_queries_filename: file with sparql query templates
            entity_linker: component deeppavlov.models.entity_extraction.entity_linking for linking of entities
            rels_in_ranking_queries_fname: file with list of rels in queries for questions with ranking
            wiki_parser: component deeppavlov.models.kbqa.wiki_parser
            entities_to_leave: how many entities to leave after entity linking
            rels_to_leave: how many relations to leave after relation ranking
            syntax_structure_known: if syntax tree parser was used to define query template type
            use_wp_api_requester: whether deeppavlov.models.api_requester.api_requester component will be used for
                Wiki Parser
            use_el_api_requester: whether deeppavlov.models.api_requester.api_requester component will be used for
                Entity Linking
            use_alt_templates: whether to use alternative templates if no answer was found for default query template
            delete_rel_prefix: whether to delete prefix in relations
            kb_prefixes: prefixes for entities, relations and types in the knowledge base
        """
        super().__init__(save_path=None, load_path=load_path)
        self.template_matcher = template_matcher
        self.entity_linker = entity_linker
        self.wiki_parser = wiki_parser
        self.rel_ranker = rel_ranker
        self.rels_in_ranking_queries_fname = rels_in_ranking_queries_fname
        self.rels_in_ranking_queries = {}
        self.entities_to_leave = entities_to_leave
        self.rels_to_leave = rels_to_leave
        self.syntax_structure_known = syntax_structure_known
        self.use_wp_api_requester = use_wp_api_requester
        self.use_el_api_requester = use_el_api_requester
        self.use_alt_templates = use_alt_templates
        self.sparql_queries_filename = sparql_queries_filename
        self.delete_rel_prefix = delete_rel_prefix
        self.kb_prefixes = kb_prefixes
        self.load()

    def load(self) -> None:
        """Read the optional ranking-relations file and the SPARQL query templates."""
        if self.rels_in_ranking_queries_fname is not None:
            self.rels_in_ranking_queries = read_json(self.load_path / self.rels_in_ranking_queries_fname)
        template_queries = read_json(str(expand_path(self.sparql_queries_filename)))
        self.template_queries = preprocess_template_queries(template_queries, self.kb_prefixes)

    def save(self) -> None:
        """Nothing to save: the component only reads its resources."""
        pass

    def find_candidate_answers(self, question: str,
                               question_sanitized: str,
                               template_types: Union[List[str], str],
                               entities_from_ner: List[str],
                               types_from_ner: List[str],
                               entity_tags: List[str],
                               probas: List[float],
                               entities_to_link: List[int],
                               answer_types: Set[str]) -> Tuple[Union[List[Dict[str, Any]], list], str]:
        """Build candidate answers for a question.

        The template matcher is tried first; if it yields no candidates, the entities and types
        coming from NER are linked and used with ``template_types`` instead.

        Args:
            question: the question; a few punctuation patterns are normalised before entity linking
            question_sanitized: the question text passed to the template matcher and the query parser
            template_types: template number(s) to use when the template matcher does not define one
            entities_from_ner: entity substrings
            types_from_ner: type substrings
            entity_tags: tags of the entities, forwarded to the entity linker
            probas: confidences of the entities, forwarded to the entity linker
            entities_to_link: forwarded to the entity linker for entities found by the template matcher
            answer_types: answer types used with the NER-based fallback

        Returns:
            tuple of candidate outputs and the answer string returned by the template matcher
        """
        candidate_outputs = []
        # Keep template_nums a flat list: wrapping an already-given list would nest it
        # and no template number could ever match.
        if isinstance(template_types, (list, tuple)):
            self.template_nums = list(template_types)
        else:
            self.template_nums = [template_types]

        # ' ' * 2 is the "double space -> single space" rule, spelled so that it survives
        # whitespace-normalising tools (a single-space pair would be a no-op).
        replace_tokens = [(' - ', '-'), (' .', ''), ('{', ''), ('}', ''), (' ' * 2, ' '), ('"', "'"), ('(', ''),
                          (')', ''), ('–', '-')]
        for old, new in replace_tokens:
            question = question.replace(old, new)

        entities_from_template, types_from_template, rels_from_template, rel_dirs_from_template, \
            query_type_template, entity_types, template_answer, template_answer_types, _ = \
            self.template_matcher(question_sanitized, entities_from_ner)
        if query_type_template:
            self.template_nums = [query_type_template]

        log.debug(
            f"question: {question} entities_from_template {entities_from_template} template_type {self.template_nums} "
            f"types from template {types_from_template} rels_from_template {rels_from_template} entities_from_ner "
            f"{entities_from_ner} types_from_ner {types_from_ner} answer_types {list(answer_types)[:3]}")

        if entities_from_template or types_from_template:
            # Guarded so that an empty relation list cannot raise IndexError.
            is_how_to = bool(rels_from_template and rels_from_template[0]
                             and rels_from_template[0][0] == "PHOW")
            if is_how_to and entities_from_template:
                how_to_content = self.find_answer_wikihow(entities_from_template[0])
                candidate_outputs = [["PHOW", how_to_content, 1.0]]
            else:
                entity_ids = self.get_entity_ids(entities_from_template, entity_tags, probas, question,
                                                 entities_to_link)
                type_ids = self.get_entity_ids(types_from_template, ["t" for _ in types_from_template],
                                               [1.0 for _ in types_from_template], question)
                log.debug(f"entities_from_template: {entities_from_template} --- entity_types: {entity_types} --- "
                          f"types_from_template: {types_from_template} --- rels_from_template: {rels_from_template} "
                          f"--- answer_types: {template_answer_types} --- entity_ids: {entity_ids}")
                candidate_outputs = self.sparql_template_parser(question_sanitized, entity_ids, type_ids,
                                                                template_answer_types, rels_from_template,
                                                                rel_dirs_from_template)

        if not candidate_outputs and (entities_from_ner or types_from_ner):
            log.debug(f"(__call__)entities_from_ner: {entities_from_ner}")
            entity_ids = self.get_entity_ids(entities_from_ner, entity_tags, probas, question)
            type_ids = self.get_entity_ids(types_from_ner, ["t" for _ in types_from_ner],
                                           [1.0 for _ in types_from_ner], question)
            log.debug(f"(__call__)entity_ids: {entity_ids} type_ids {type_ids}")
            # A plain string is turned into a list inside sparql_template_parser.
            self.template_nums = template_types
            log.debug(f"(__call__)self.template_nums: {self.template_nums}")
            if not self.syntax_structure_known:
                entity_ids = entity_ids[:3]
            candidate_outputs = self.sparql_template_parser(question_sanitized, entity_ids, type_ids, answer_types)
        return candidate_outputs, template_answer

    def get_entity_ids(self, entities: List[str], tags: List[str], probas: List[float], question: str,
                       entities_to_link: List[int] = None) -> List[List[str]]:
        """Link entity substrings to knowledge base ids with the entity linker.

        Args:
            entities: entity substrings
            tags: tags of the entities
            probas: confidences of the entities
            question: the question the entities were found in
            entities_to_link: forwarded to the entity linker as is

        Returns:
            candidate ids extracted from the linker output; an empty list if the linker returned
            nothing or its output could not be decoded as JSON
        """
        entity_ids, el_output = [], []
        try:
            # Every argument is wrapped into a batch of size one.
            el_output = self.entity_linker([entities], [tags], [probas], [[question]], [None], [None],
                                           [entities_to_link])
        except json.decoder.JSONDecodeError:
            log.warning("not received output from entity linking")
        if not el_output:
            return entity_ids

        if self.use_el_api_requester:
            el_output = el_output[0]
        if el_output:
            if isinstance(el_output[0], dict):
                entity_ids = [entity_info.get("entity_ids", []) for entity_info in el_output]
            elif isinstance(el_output[0], list):
                entity_ids, *_ = el_output
        if not self.use_el_api_requester and entity_ids:
            # Without the API requester the ids are still wrapped in the batch dimension.
            entity_ids = entity_ids[0]
        return entity_ids

    def sparql_template_parser(self, question: str,
                               entity_ids: List[List[str]],
                               type_ids: List[List[str]],
                               answer_types: Set[str],
                               rels_from_template: Optional[List[Tuple[str]]] = None,
                               rel_dirs_from_template: Optional[List[str]] = None) -> Union[List[Dict[str, Any]], list]:
        """Select query templates for ``self.template_nums`` and run :meth:`query_parser` on them.

        Args:
            question: the question passed on to the query parser
            entity_ids: candidate ids for every entity
            type_ids: candidate ids for every type
            answer_types: answer types passed on to the query parser
            rels_from_template: relations defined by the template matcher; when given, only the
                template whose ``rel_dirs`` equal ``rel_dirs_from_template`` is used
            rel_dirs_from_template: relation directions defined by the template matcher

        Returns:
            candidate outputs produced by the query parser; an empty list if no template fits
        """
        candidate_outputs = []
        if isinstance(self.template_nums, str):
            self.template_nums = [self.template_nums]
        template_log_list = [str([elem["query_template"], elem["template_num"]])
                             for elem in self.template_queries.values() if elem["template_num"] in self.template_nums]
        log.debug(f"(find_candidate_answers)self.template_nums: {' --- '.join(template_log_list)}")

        # With a known syntax structure the requested number is the key of the template,
        # otherwise it is the template's "template_num" field.
        init_templates = []
        for template_num in self.template_nums:
            for num, template in self.template_queries.items():
                template_id = num if self.syntax_structure_known else template["template_num"]
                if template_id == template_num:
                    init_templates.append(template)

        if self.syntax_structure_known:
            templates = list(init_templates)
        else:
            # Prefer templates with matching numbers of entities and types, then fall back
            # to templates that use the entities only.
            slots_num = [len(entity_ids), len(type_ids)]
            templates = [template for template in init_templates
                         if template["entities_and_types_num"] == slots_num]
            if not templates:
                slots_num = [len(entity_ids), 0]
                templates = [template for template in init_templates
                             if template["entities_and_types_num"] == slots_num]
        if not templates:
            return candidate_outputs

        if rels_from_template is not None:
            query_template = {}
            for template in templates:
                # No break: the last template with matching directions wins.
                if template["rel_dirs"] == rel_dirs_from_template:
                    query_template = template
            if query_template:
                candidate_outputs = self.query_parser(question, [query_template], entity_ids, type_ids, answer_types,
                                                      rels_from_template)
        else:
            for priority in range(1, 3):
                pr_templates = [template for template in templates if template["priority"] == priority]
                candidate_outputs = self.query_parser(question, pr_templates, entity_ids, type_ids, answer_types,
                                                      rels_from_template)
                if candidate_outputs:
                    break

            # Alternative templates are tried only when enabled in the constructor.
            if not candidate_outputs and self.use_alt_templates:
                alt_template_nums = templates[0].get("alternative_templates", [])
                log.debug(f"Using alternative templates {alt_template_nums}")
                alt_templates = [self.template_queries[num] for num in alt_template_nums]
                candidate_outputs = self.query_parser(question, alt_templates, entity_ids, type_ids, answer_types,
                                                      rels_from_template)

        log.debug("candidate_rels_and_answers:\n" + '\n'.join([str(output) for output in candidate_outputs[:5]]))
        return candidate_outputs

    def _find_entity_rels(self, entity_ids: List[List[str]], direction: str,
                          rel_type: str) -> Tuple[list, Set[Tuple[str, str]]]:
        """Ask the wiki parser for the relations of the top candidate entities in one direction.

        Returns:
            the wiki parser output (one element per query) and the set of
            (entity id, relation without prefix) pairs found in it
        """
        queries_list = list({(entity, direction, rel_type) for entity_id in entity_ids
                             for entity in entity_id[:self.entities_to_leave]})
        found_rels = self.wiki_parser(["find_rels"] * len(queries_list), queries_list)
        connections = set()
        for rels, query in zip(found_rels, queries_list):
            for rel in rels:
                connections.add((query[0], rel.split("/")[-1]))
        return found_rels, connections

    def find_top_rels(self, question: str, entity_ids: List[List[str]], triplet_info: Tuple) -> \
            Tuple[List[Tuple[str, float]], Dict[str, float], Set[Tuple[str, str]]]:
        """Collect candidate relations for a triplet and rank them against the question.

        Args:
            question: the question used for relation ranking
            entity_ids: candidate ids for every entity
            triplet_info: tuple ``(direction, source, rel_type, n_hop)``; ``source`` selects where the
                candidate relations come from ("wiki" or one of the predefined relation lists)

        Returns:
            the ``rels_to_leave`` best (relation, score) pairs, a relation-to-score dict over all ranked
            relations and the set of (entity id, relation) pairs reported by the wiki parser
        """
        ex_rels, entity_rel_conn = [], set()
        direction, source, rel_type, n_hop = triplet_info
        if source == "wiki":
            ex_rels, connections = self._find_entity_rels(entity_ids, direction, rel_type)
            entity_rel_conn |= connections
            # NOTE(review): the API-requester unwrapping is applied to the forward query only,
            # exactly as before - confirm whether the backward query needs it as well.
            if self.use_wp_api_requester and ex_rels:
                ex_rels = [rel[0] for rel in ex_rels]
            ex_rels = list(set(itertools.chain.from_iterable(ex_rels)))
            if n_hop in {"1-of-2-hop", "2-hop"}:
                ex_rels_backw, connections = self._find_entity_rels(entity_ids, "backw", rel_type)
                entity_rel_conn |= connections
                ex_rels += list(set(itertools.chain.from_iterable(ex_rels_backw)))
            if self.delete_rel_prefix:
                ex_rels = [rel.split('/')[-1] for rel in ex_rels]
        elif source in {"rank_list_1", "rel_list_1"}:
            ex_rels = self.rels_in_ranking_queries.get("one_rel_in_query", [])
        elif source in {"rank_list_2", "rel_list_2"}:
            ex_rels = self.rels_in_ranking_queries.get("two_rels_in_query", [])

        # kb_prefixes defaults to None; in that case no relation is treated as a type relation.
        type_rels = tuple((self.kb_prefixes or {}).get("type_rels", ()))
        ex_rels = [rel for rel in ex_rels if not rel.endswith(type_rels)]
        rels_with_scores = self.rel_ranker.rank_rels(question, ex_rels)
        if n_hop == "2-hop" and rels_with_scores and entity_ids and entity_ids[0]:
            rels_1hop = [rel for rel, _ in rels_with_scores]
            queries_list = [(entity_ids[0], rels_1hop[:5])]
            parser_info_list = ["find_rels_2hop"]
            ex_rels_2hop = self.wiki_parser(parser_info_list, queries_list)
            if self.delete_rel_prefix:
                ex_rels_2hop = [rel.split('/')[-1] for rel in ex_rels_2hop]
            rels_with_scores = self.rel_ranker.rank_rels(question, ex_rels_2hop)

        # Drop duplicated (relation, score) pairs, best score first.
        rels_with_scores = sorted(set(rels_with_scores), key=lambda x: x[1], reverse=True)
        rels_scores_dict = {rel: score for rel, score in rels_with_scores}

        return rels_with_scores[:self.rels_to_leave], rels_scores_dict, entity_rel_conn

    def find_answer_wikihow(self, howto_sentence: str) -> str:
        """Return the first paragraph of the first wikiHow search result for ``howto_sentence``.

        Returns:
            the stripped paragraph text with the "@en" suffix, or "Not Found" if the search gave
            no results or the article page has no paragraphs
        """
        search_results = search(howto_sentence, 5)
        if not search_results:
            return "Not Found"
        html = get_html(search_results[0]["article_id"])
        page = BeautifulSoup(html, 'lxml')
        paragraphs = list(page.find_all(['p']))
        if not paragraphs:
            return "Not Found"
        return f"{paragraphs[0].text.strip()}@en"

    def query_parser(self, question, query_templates, entity_ids, type_ids, answer_types, rels_from_template):
        """Fill the given query templates and return candidate outputs; must be implemented by subclasses."""
        raise NotImplementedError