# Source code for deeppavlov.models.kbqa.query_generator_base

# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import re
import time
from logging import getLogger
from typing import Tuple, List, Optional, Union, Dict, Any

import nltk

from deeppavlov.core.models.component import Component
from deeppavlov.core.models.serializable import Serializable
from deeppavlov.core.common.file import read_json
from deeppavlov.core.commands.utils import expand_path
from deeppavlov.models.kbqa.template_matcher import TemplateMatcher
from deeppavlov.models.kbqa.entity_linking import EntityLinker
from deeppavlov.models.kbqa.rel_ranking_infer import RelRankerInfer
from deeppavlov.models.kbqa.rel_ranking_bert_infer import RelRankerBertInfer

log = getLogger(__name__)

class QueryGeneratorBase(Component, Serializable):
    """
        This class takes as input entity substrings, defines the template of the query and
        fills the slots of the template with candidate entities and relations.

        NOTE: ``sparql_template_parser`` calls ``self.query_parser``, which is not defined in
        this class; presumably every concrete subclass provides it.
    """

    def __init__(self, template_matcher: TemplateMatcher,
                 linker_entities: EntityLinker,
                 linker_types: EntityLinker,
                 rel_ranker: Union[RelRankerInfer, RelRankerBertInfer],
                 load_path: str,
                 rank_rels_filename_1: str,
                 rank_rels_filename_2: str,
                 sparql_queries_filename: str,
                 wiki_parser=None,
                 entities_to_leave: int = 5,
                 rels_to_leave: int = 7,
                 syntax_structure_known: bool = False,
                 return_answers: bool = False, *args, **kwargs) -> None:
        """

        Args:
            template_matcher: component deeppavlov.models.kbqa.template_matcher
            linker_entities: component deeppavlov.models.kbqa.entity_linking for linking of entities
            linker_types: component deeppavlov.models.kbqa.entity_linking for linking of types
            rel_ranker: component deeppavlov.models.kbqa.rel_ranking_infer
            load_path: path to folder with wikidata files
            rank_rels_filename_1: file with list of rels for first rels in questions with ranking
            rank_rels_filename_2: file with list of rels for second rels in questions with ranking
            sparql_queries_filename: file with sparql query templates
            wiki_parser: component deeppavlov.models.kbqa.wiki_parser
            entities_to_leave: how many entities to leave after entity linking
            rels_to_leave: how many relations to leave after relation ranking
            syntax_structure_known: if syntax tree parser was used to define query template type
            return_answers: whether to return answers or candidate answers

        The rank lists and query templates are loaded eagerly via ``self.load()``.
        """
        super().__init__(save_path=None, load_path=load_path)
        self.template_matcher = template_matcher
        self.linker_entities = linker_entities
        self.linker_types = linker_types
        self.wiki_parser = wiki_parser
        self.rel_ranker = rel_ranker
        self.rank_rels_filename_1 = rank_rels_filename_1
        self.rank_rels_filename_2 = rank_rels_filename_2
        self.rank_list_0 = []
        self.rank_list_1 = []
        self.entities_to_leave = entities_to_leave
        self.rels_to_leave = rels_to_leave
        self.syntax_structure_known = syntax_structure_known
        self.sparql_queries_filename = sparql_queries_filename
        self.return_answers = return_answers
        self.load()

    def load(self) -> None:
        """Load the two relation rank lists and the SPARQL query templates.

        For each line of the rank-list files only the first tab-separated column is kept
        (the trailing newline is stripped first, so single-column files work too).
        """
        with open(self.load_path / self.rank_rels_filename_1, 'r', encoding='utf-8') as fl1:
            self.rank_list_0 = [line.rstrip('\n').split('\t')[0] for line in fl1]

        with open(self.load_path / self.rank_rels_filename_2, 'r', encoding='utf-8') as fl2:
            self.rank_list_1 = [line.rstrip('\n').split('\t')[0] for line in fl2]

        self.template_queries = read_json(str(expand_path(self.sparql_queries_filename)))

    def save(self) -> None:
        """Nothing to persist: this component only reads its resources."""
        pass

    def find_candidate_answers(self, question: str,
                               question_sanitized: str,
                               template_types: Union[List[str], str],
                               entities_from_ner: List[str],
                               types_from_ner: List[str]) -> Union[List[Tuple[str, Any]], List[str]]:
        """Build candidate outputs for a question by filling SPARQL templates.

        First the template matcher is tried. If it extracts entities or types, they are linked
        and used to fill its template. If that yields nothing and NER entities are available,
        the NER entities/types are linked and ``template_types`` is used instead.

        Args:
            question: raw question; a normalized copy is passed to the template matcher and linker.
            question_sanitized: question text passed on to ``sparql_template_parser``.
            template_types: template identifiers to fall back to when the matcher fails.
            entities_from_ner: entity substrings found by NER.
            types_from_ner: type substrings found by NER.

        Returns:
            Whatever ``sparql_template_parser`` returned, or ``[]`` if neither path ran.
        """
        candidate_outputs = []
        self.template_nums = template_types

        # Normalize punctuation/spacing before template matching and entity linking.
        replace_tokens = [(' - ', '-'), (' .', ''), ('{', ''), ('}', ''), ('  ', ' '), ('"', "'"), ('(', ''),
                          (')', ''), ('–', '-')]
        for old, new in replace_tokens:
            question = question.replace(old, new)

        entities_from_template, types_from_template, rels_from_template, rel_dirs_from_template, \
            query_type_template, template_found = self.template_matcher(question, entities_from_ner)
        self.template_nums = [query_type_template]

        log.debug(f"question: {question}\n")
        log.debug(f"template_type {self.template_nums}")

        if entities_from_template or types_from_template:
            entity_ids = self.get_entity_ids(entities_from_template, "entities", template_found, question)
            type_ids = self.get_entity_ids(types_from_template, "types")
            log.debug(f"entities_from_template {entities_from_template}")
            log.debug(f"types_from_template {types_from_template}")
            log.debug(f"rels_from_template {rels_from_template}")
            log.debug(f"entity_ids {entity_ids}")
            log.debug(f"type_ids {type_ids}")
            candidate_outputs = self.sparql_template_parser(question_sanitized, entity_ids, type_ids,
                                                            rels_from_template, rel_dirs_from_template)

        if not candidate_outputs and entities_from_ner:
            # Fallback: use NER output and the externally supplied template types.
            log.debug(f"(__call__)entities_from_ner: {entities_from_ner}")
            log.debug(f"(__call__)types_from_ner: {types_from_ner}")
            entity_ids = self.get_entity_ids(entities_from_ner, "entities", question=question)
            type_ids = self.get_entity_ids(types_from_ner, "types")
            log.debug(f"(__call__)entity_ids: {entity_ids}")
            log.debug(f"(__call__)type_ids: {type_ids}")
            self.template_nums = template_types
            log.debug(f"(__call__)self.template_nums: {self.template_nums}")
            if not self.syntax_structure_known:
                # Without a syntax tree, only the first three entity substrings are used.
                entity_ids = entity_ids[:3]
            candidate_outputs = self.sparql_template_parser(question_sanitized, entity_ids, type_ids)
        return candidate_outputs

    def get_entity_ids(self, entities: List[str],
                       what_to_link: str,
                       template_found: Optional[str] = None,
                       question: Optional[str] = None) -> List[List[str]]:
        """Link each substring to candidate ids.

        Args:
            entities: substrings to link.
            what_to_link: ``"entities"`` (uses ``linker_entities`` with the question as context)
                or ``"types"`` (uses ``linker_types``); any other value yields an empty list per
                substring.
            template_found: forwarded to the entity linker.
            question: forwarded to the entity linker as context.

        Returns:
            One list of at most 15 candidate ids per input substring.
        """
        entity_ids = []
        for entity in entities:
            entity_id = []
            if what_to_link == "entities":
                entity_id, _ = self.linker_entities.link_entity(entity, context=question,
                                                                template_found=template_found)
            elif what_to_link == "types":
                entity_id, _ = self.linker_types.link_entity(entity)
            entity_ids.append(entity_id[:15])
        return entity_ids

    def sparql_template_parser(self, question: str,
                               entity_ids: List[List[str]],
                               type_ids: List[List[str]],
                               rels_from_template: Optional[List[Tuple[str]]] = None,
                               rel_dirs_from_template: Optional[List[str]] = None) -> List[Tuple[str]]:
        """Select query templates matching ``self.template_nums`` and fill them via ``query_parser``.

        When the syntax structure is known, templates are selected by their key in
        ``self.template_queries``. Otherwise they are selected by their ``"template_num"`` field
        and filtered to those whose ``"entities_and_types_num"`` equals
        ``[len(entity_ids), len(type_ids)]``.

        With ``rels_from_template`` given, the (last) template whose ``"rel_dirs"`` equals
        ``rel_dirs_from_template`` is used. Otherwise each selected template is tried in order,
        then the ``"alternative_templates"`` of the first selected template; the first non-empty
        result is returned.
        """
        candidate_outputs = []
        log.debug(f"(find_candidate_answers)self.template_nums: {self.template_nums}")
        templates = []
        for template_num in self.template_nums:
            for num, template in self.template_queries.items():
                if (num == template_num and self.syntax_structure_known) or \
                        (template["template_num"] == template_num and not self.syntax_structure_known):
                    templates.append(template)
        templates = [template for template in templates if
                     (not self.syntax_structure_known and
                      [len(entity_ids), len(type_ids)] == template["entities_and_types_num"])
                     or self.syntax_structure_known]
        templates_string = '\n'.join([template["query_template"] for template in templates])
        log.debug(f"{templates_string}")
        if not templates:
            return candidate_outputs

        if rels_from_template is not None:
            query_template = {}
            for template in templates:
                if template["rel_dirs"] == rel_dirs_from_template:
                    query_template = template
            if query_template:
                entities_and_types_select = query_template["entities_and_types_select"]
                candidate_outputs = self.query_parser(question, query_template, entities_and_types_select,
                                                      entity_ids, type_ids, rels_from_template)
        else:
            for template in templates:
                entities_and_types_select = template["entities_and_types_select"]
                candidate_outputs = self.query_parser(question, template, entities_and_types_select,
                                                      entity_ids, type_ids, rels_from_template)
                if candidate_outputs:
                    return candidate_outputs

            # Fall back to the alternatives of the first matching template. Stop at the first
            # one that yields candidates (previously the loop returned after the first try).
            alternative_templates = templates[0]["alternative_templates"]
            for template_num, entities_and_types_select in alternative_templates:
                candidate_outputs = self.query_parser(question, self.template_queries[template_num],
                                                      entities_and_types_select, entity_ids, type_ids,
                                                      rels_from_template)
                if candidate_outputs:
                    return candidate_outputs

        log.debug("candidate_rels_and_answers:\n" + '\n'.join([str(output) for output in candidate_outputs[:5]]))
        return candidate_outputs

    def find_top_rels(self, question: str, entity_ids: List[List[str]], triplet_info: Tuple) -> List[str]:
        """Rank candidate relations for the question and keep the top ``rels_to_leave``.

        Args:
            question: question text passed to the relation ranker.
            entity_ids: candidate id lists per entity substring.
            triplet_info: ``(direction, source, rel_type)``. ``source`` is ``"wiki"`` (relations
                of the first ``entities_to_leave`` ids per list, fetched from ``wiki_parser``),
                ``"rank_list_1"`` (``self.rank_list_0``) or ``"rank_list_2"``
                (``self.rank_list_1``); any other source gives no candidates.

        Returns:
            The first ``rels_to_leave`` items returned by ``rel_ranker.rank_rels``.
        """
        ex_rels = []
        direction, source, rel_type = triplet_info
        if source == "wiki":
            for entity_id in entity_ids:
                for entity in entity_id[:self.entities_to_leave]:
                    ex_rels += self.wiki_parser.find_rels(entity, direction, rel_type)
            # De-duplicate while keeping first-seen order, so ranker input is deterministic
            # (a set's iteration order depends on the string hash seed).
            ex_rels = list(dict.fromkeys(ex_rels))
            ex_rels = [rel.split('/')[-1] for rel in ex_rels]
        elif source == "rank_list_1":
            ex_rels = self.rank_list_0
        elif source == "rank_list_2":
            ex_rels = self.rank_list_1
        rels_with_scores = self.rel_ranker.rank_rels(question, ex_rels)
        return rels_with_scores[:self.rels_to_leave]