Source code for deeppavlov.models.kbqa.query_generator_base

# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from logging import getLogger
from typing import Tuple, List, Optional, Union, Any

from whapi import search, get_html
from bs4 import BeautifulSoup

from deeppavlov.core.models.component import Component
from deeppavlov.core.models.serializable import Serializable
from deeppavlov.core.common.file import read_json
from deeppavlov.core.commands.utils import expand_path
from deeppavlov.models.kbqa.template_matcher import TemplateMatcher
from deeppavlov.models.kbqa.entity_linking import EntityLinker
from deeppavlov.models.kbqa.rel_ranking_infer import RelRankerInfer
from deeppavlov.models.kbqa.rel_ranking_bert_infer import RelRankerBertInfer

log = getLogger(__name__)


[docs]class QueryGeneratorBase(Component, Serializable): """ This class takes as input entity substrings, defines the template of the query and fills the slots of the template with candidate entities and relations. """
[docs] def __init__(self, template_matcher: TemplateMatcher, linker_entities: EntityLinker, linker_types: EntityLinker, rel_ranker: Union[RelRankerInfer, RelRankerBertInfer], load_path: str, rank_rels_filename_1: str, rank_rels_filename_2: str, sparql_queries_filename: str, wiki_parser=None, wiki_file_format: str = "hdt", entities_to_leave: int = 5, rels_to_leave: int = 7, syntax_structure_known: bool = False, use_api_requester: bool = False, return_answers: bool = False, *args, **kwargs) -> None: """ Args: template_matcher: component deeppavlov.models.kbqa.template_matcher linker_entities: component deeppavlov.models.kbqa.entity_linking for linking of entities linker_types: component deeppavlov.models.kbqa.entity_linking for linking of types rel_ranker: component deeppavlov.models.kbqa.rel_ranking_infer load_path: path to folder with wikidata files rank_rels_filename_1: file with list of rels for first rels in questions with ranking rank_rels_filename_2: file with list of rels for second rels in questions with ranking sparql_queries_filename: file with sparql query templates wiki_file_format: format of wikidata file wiki_parser: component deeppavlov.models.kbqa.wiki_parser entities_to_leave: how many entities to leave after entity linking rels_to_leave: how many relations to leave after relation ranking syntax_structure_known: if syntax tree parser was used to define query template type use_api_requester: whether deeppavlov.models.api_requester.api_requester component will be used for Entity Linking and Wiki Parser return_answers: whether to return answers or candidate answers """ super().__init__(save_path=None, load_path=load_path) self.template_matcher = template_matcher self.linker_entities = linker_entities self.linker_types = linker_types self.wiki_parser = wiki_parser self.wiki_file_format = wiki_file_format self.rel_ranker = rel_ranker self.rank_rels_filename_1 = rank_rels_filename_1 self.rank_rels_filename_2 = rank_rels_filename_2 self.rank_list_0 = [] self.rank_list_1 = [] self.entities_to_leave = entities_to_leave self.rels_to_leave = rels_to_leave self.syntax_structure_known = syntax_structure_known self.use_api_requester = use_api_requester self.sparql_queries_filename = sparql_queries_filename self.return_answers = return_answers self.load()
def load(self) -> None: with open(self.load_path / self.rank_rels_filename_1, 'r') as fl1: lines = fl1.readlines() self.rank_list_0 = [line.split('\t')[0] for line in lines] with open(self.load_path / self.rank_rels_filename_2, 'r') as fl2: lines = fl2.readlines() self.rank_list_1 = [line.split('\t')[0] for line in lines] self.template_queries = read_json(str(expand_path(self.sparql_queries_filename))) def save(self) -> None: pass def find_candidate_answers(self, question: str, question_sanitized: str, template_types: Union[List[str], str], entities_from_ner: List[str], types_from_ner: List[str]) -> Union[List[Tuple[str, Any]], List[str]]: candidate_outputs = [] self.template_nums = template_types replace_tokens = [(' - ', '-'), (' .', ''), ('{', ''), ('}', ''), (' ', ' '), ('"', "'"), ('(', ''), (')', ''), ('–', '-')] for old, new in replace_tokens: question = question.replace(old, new) entities_from_template, types_from_template, rels_from_template, rel_dirs_from_template, query_type_template, \ entity_types, template_answer, template_found = self.template_matcher(question_sanitized, entities_from_ner) self.template_nums = [query_type_template] log.debug(f"question: {question}\n") log.debug(f"template_type {self.template_nums}") if entities_from_template or types_from_template: if rels_from_template[0][0] == "PHOW": how_to_content = self.find_answer_wikihow(entities_from_template[0]) candidate_outputs = [["PHOW", how_to_content, 1.0]] else: entity_ids = self.get_entity_ids(entities_from_template, "entities", template_found, question, entity_types) type_ids = self.get_entity_ids(types_from_template, "types") log.debug(f"entities_from_template {entities_from_template}") log.debug(f"entity_types {entity_types}") log.debug(f"types_from_template {types_from_template}") log.debug(f"rels_from_template {rels_from_template}") log.debug(f"entity_ids {entity_ids}") log.debug(f"type_ids {type_ids}") candidate_outputs = self.sparql_template_parser(question_sanitized, entity_ids, type_ids, rels_from_template, rel_dirs_from_template) if not candidate_outputs and entities_from_ner: log.debug(f"(__call__)entities_from_ner: {entities_from_ner}") log.debug(f"(__call__)types_from_ner: {types_from_ner}") entity_ids = self.get_entity_ids(entities_from_ner, "entities", question=question) type_ids = self.get_entity_ids(types_from_ner, "types") log.debug(f"(__call__)entity_ids: {entity_ids}") log.debug(f"(__call__)type_ids: {type_ids}") self.template_nums = template_types log.debug(f"(__call__)self.template_nums: {self.template_nums}") if not self.syntax_structure_known: entity_ids = entity_ids[:3] candidate_outputs = self.sparql_template_parser(question_sanitized, entity_ids, type_ids) return candidate_outputs, template_answer def get_entity_ids(self, entities: List[str], what_to_link: str, template_found: str = None, question: str = None, entity_types: List[List[str]] = None) -> List[List[str]]: entity_ids = [] if what_to_link == "entities": if entity_types: el_output = self.linker_entities([entities], [template_found], [question], [entity_types]) else: el_output = self.linker_entities([entities], [template_found], [question]) if self.use_api_requester: el_output = el_output[0] entity_ids, _ = el_output if not self.use_api_requester and entity_ids: entity_ids = entity_ids[0] if what_to_link == "types": entity_ids, _ = self.linker_types([entities]) entity_ids = entity_ids[0] return entity_ids def sparql_template_parser(self, question: str, entity_ids: List[List[str]], type_ids: List[List[str]], rels_from_template: Optional[List[Tuple[str]]] = None, rel_dirs_from_template: Optional[List[str]] = None) -> List[Tuple[str]]: candidate_outputs = [] log.debug(f"(find_candidate_answers)self.template_nums: {self.template_nums}") templates = [] for template_num in self.template_nums: for num, template in self.template_queries.items(): if (num == template_num and self.syntax_structure_known) or \ (template["template_num"] == template_num and not self.syntax_structure_known): templates.append(template) templates = [template for template in templates if (not self.syntax_structure_known and [len(entity_ids), len(type_ids)] == template[ "entities_and_types_num"]) or self.syntax_structure_known] templates_string = '\n'.join([template["query_template"] for template in templates]) log.debug(f"{templates_string}") if not templates: return candidate_outputs if rels_from_template is not None: query_template = {} for template in templates: if template["rel_dirs"] == rel_dirs_from_template: query_template = template if query_template: entities_and_types_select = query_template["entities_and_types_select"] candidate_outputs = self.query_parser(question, query_template, entities_and_types_select, entity_ids, type_ids, rels_from_template) else: for template in templates: entities_and_types_select = template["entities_and_types_select"] candidate_outputs = self.query_parser(question, template, entities_and_types_select, entity_ids, type_ids, rels_from_template) if candidate_outputs: return candidate_outputs if not candidate_outputs: alternative_templates = templates[0]["alternative_templates"] for template_num, entities_and_types_select in alternative_templates: candidate_outputs = self.query_parser(question, self.template_queries[template_num], entities_and_types_select, entity_ids, type_ids, rels_from_template) if candidate_outputs: return candidate_outputs log.debug("candidate_rels_and_answers:\n" + '\n'.join([str(output) for output in candidate_outputs[:5]])) return candidate_outputs def find_top_rels(self, question: str, entity_ids: List[List[str]], triplet_info: Tuple) -> List[Tuple[str, Any]]: ex_rels = [] direction, source, rel_type = triplet_info if source == "wiki": queries_list = list({(entity, direction, rel_type) for entity_id in entity_ids for entity in entity_id[:self.entities_to_leave]}) parser_info_list = ["find_rels" for i in range(len(queries_list))] ex_rels = self.wiki_parser(parser_info_list, queries_list) if self.use_api_requester and ex_rels: ex_rels = [rel[0] for rel in ex_rels] ex_rels = list(set(ex_rels)) ex_rels = [rel.split('/')[-1] for rel in ex_rels] elif source == "rank_list_1": ex_rels = self.rank_list_0 elif source == "rank_list_2": ex_rels = self.rank_list_1 rels_with_scores = self.rel_ranker.rank_rels(question, ex_rels) return rels_with_scores[:self.rels_to_leave] def find_answer_wikihow(self, howto_sentence: str) -> str: search_results = search(howto_sentence, 5) article_id = search_results[0]["article_id"] html = get_html(article_id) page = BeautifulSoup(html, 'lxml') tags = list(page.find_all(['p'])) if tags: howto_content = f"{tags[0].text.strip()}@en" else: howto_content = "Not Found" return howto_content