# Source code for deeppavlov.models.kbqa.kb_answer_parser_simple

# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from logging import getLogger
from typing import List, Tuple, Optional, Union

import numpy as np

from deeppavlov.core.common.registry import register
from deeppavlov.models.kbqa.kb_answer_parser_base import KBBase

log = getLogger(__name__)


@register('kb_answer_parser_simple')
class KBAnswerParserSimple(KBBase):
    """
    This class generates an answer for a given question using Wikidata.
    It searches for a matching triplet in Wikidata with the entity and relation mentioned in the question.
    It uses results of the Named Entity Recognition component to extract the entity mention and of the
    Classification component to determine the relation which connects the extracted entity and the answer entity.

    NOTE(review): ``is_kbqa_question``, ``template_matcher``, ``linker``, ``match_triplet``,
    ``parse_wikidata_object``, ``wikidata`` and ``_relations_mapping`` are used here but not defined in this
    class; they are presumably provided by ``KBBase`` - confirm against the base class.
    """

    def __init__(self, top_k_classes: int, rule_filter_entities: bool = False, return_confidences: bool = False,
                 language: str = "eng", relations_maping_filename: Optional[str] = None,
                 templates_filename: Optional[str] = None, *args, **kwargs) -> None:
        """

        Args:
            top_k_classes: number of relations with top k probabilities
            rule_filter_entities: whether to filter entities with rules
                (the rules are applied only when ``language`` is ``'rus'``)
            return_confidences: whether ``__call__`` returns confidences together with the answers
            language: russian or english
            relations_maping_filename: file with the dictionary of ids(keys) and titles(values) of relations
                from Wikidata
            templates_filename: file with the dictionary of question templates(keys) and relations for these
                templates (values); when ``None`` the template matcher is not called
            *args: forwarded to the base class constructor
            **kwargs: forwarded to the base class constructor
        """
        self.top_k_classes = top_k_classes
        self.rule_filter_entities = rule_filter_entities
        self.return_confidences = return_confidences
        self.language = language
        self._relations_filename = relations_maping_filename
        self._templates_filename = templates_filename
        super().__init__(*args, relations_maping_filename=self._relations_filename, **kwargs)

    def __call__(self, questions_batch: List[str],
                 tokens_batch: List[List[str]],
                 tags_batch: List[List[int]],
                 relations_probs_batch: List[List[float]],
                 relations_labels_batch: List[List[str]],
                 *args, **kwargs) -> Union[Tuple[List[str], List[float]], List[str]]:
        """
        Find answers for a batch of questions.

        For every question recognised as a KBQA question the entity and the relation are taken from the
        template matcher (when a templates file is configured and a template matched) or, otherwise, from
        the NER tags and the relation classifier output. Questions that are not KBQA questions get an empty
        answer with confidence ``0.0``.

        Args:
            questions_batch: questions
            tokens_batch: tokens of the questions
            tags_batch: NER tags of the tokens; tokens whose tag is neither ``'O'`` nor ``0`` form the entity
            relations_probs_batch: relation probabilities for every question
            relations_labels_batch: relation ids for every question (keys of the relations mapping)
            *args: unused
            **kwargs: unused

        Returns:
            the parsed answers, or a tuple of the parsed answers and their confidences
            if ``return_confidences`` is set
        """
        objects_batch = []
        confidences_batch = []

        for question, tokens, tags, relations_probs, relations_labels in \
                zip(questions_batch, tokens_batch, tags_batch, relations_probs_batch, relations_labels_batch):
            if not self.is_kbqa_question(question, self.language):
                objects_batch.append('')
                confidences_batch.append(0.0)
                continue

            entity_from_template = []
            relations_from_template = []
            if self._templates_filename is not None:
                entity_from_template, _, relations_from_template, _, _ = self.template_matcher(question)

            if entity_from_template:
                relation_from_template = relations_from_template[0][0]
                relation_title = self._relations_mapping[relation_from_template]["name"]
                log.debug("entity %s, relation %s", entity_from_template, relation_title)
                entity_ids, entity_linking_confidences = self.linker(entity_from_template[0])
                log.debug("entity_ids %s", entity_ids[:5])
                entity_triplets, entity_linking_confidences = \
                    self._triplets_for_entities(entity_ids, entity_linking_confidences, tokens)
                # The relation comes from the matched template, so it is used with probability 1.0
                obj, confidence = self.match_triplet(entity_triplets,
                                                     entity_linking_confidences,
                                                     [relation_from_template],
                                                     [1.0])
            else:
                entity_from_ner = self.extract_entities(tokens, tags)
                entity_ids, entity_linking_confidences = self.linker(entity_from_ner)
                entity_triplets, entity_linking_confidences = \
                    self._triplets_for_entities(entity_ids, entity_linking_confidences, tokens)
                top_k_probs = self._parse_relations_probs(relations_probs)
                top_k_relation_names = [self._relations_mapping[rel]["name"] for rel in relations_labels]
                log.debug("entity_from_ner %s, top k relations %s", entity_from_ner, top_k_relation_names)
                # NOTE(review): relations_labels is passed as received while top_k_probs is sorted in
                # descending order; presumably the labels already arrive as the top k in that order - confirm.
                obj, confidence = self.match_triplet(entity_triplets,
                                                     entity_linking_confidences,
                                                     relations_labels,
                                                     top_k_probs)

            objects_batch.append(obj)
            confidences_batch.append(confidence)

        parsed_objects_batch, confidences_batch = self.parse_wikidata_object(objects_batch, confidences_batch)
        if self.return_confidences:
            return parsed_objects_batch, confidences_batch
        return parsed_objects_batch

    def _triplets_for_entities(self, entity_ids: List[str], confidences: List[float],
                               tokens: List[str]) -> Tuple[List[List[List[str]]], List[float]]:
        """Collect triplets for the linked entities and apply the rule-based filter for Russian if enabled."""
        entity_triplets = self.extract_triplets_from_wiki(entity_ids)
        if self.rule_filter_entities and self.language == 'rus':
            _, entity_triplets, confidences = \
                self.filter_triplets_rus(entity_triplets, confidences, tokens, entity_ids)
        return entity_triplets, confidences

    def _parse_relations_probs(self, probs: List[float]) -> List[float]:
        """
        Return the ``top_k_classes`` largest probabilities in descending order.

        Note that ``top_k_classes == 0`` yields all probabilities because ``[-0:]`` slices the whole array.
        """
        top_k_inds = np.asarray(probs).argsort()[-self.top_k_classes:][::-1]
        return [probs[k] for k in top_k_inds]

    @staticmethod
    def extract_entities(tokens: List[str], tags: List[int]) -> str:
        """
        Join the tokens marked as a part of an entity into a single string.

        Args:
            tokens: tokens of the question
            tags: tags of the tokens; a token belongs to the entity when its tag is neither ``'O'`` nor ``0``

        Returns:
            entity tokens joined with spaces (an empty string if no token is tagged)
        """
        return ' '.join(tok for tok, tag in zip(tokens, tags) if tag != 'O' and tag != 0)

    def extract_triplets_from_wiki(self, entity_ids: List[str]) -> List[List[List[str]]]:
        """
        Look up the triplets of every entity id in ``self.wikidata``.

        Args:
            entity_ids: candidate entity ids

        Returns:
            a list aligned with ``entity_ids``; ids which are absent from ``self.wikidata`` or do not start
            with ``'Q'`` get an empty list
        """
        return [self.wikidata[entity_id] if entity_id in self.wikidata and entity_id.startswith('Q') else []
                for entity_id in entity_ids]

    def filter_triplets_rus(self, entity_triplets: List[List[List[str]]], confidences: List[float],
                            question_tokens: List[str], srtd_cand_ent: List[Tuple[str]]) -> \
            Tuple[List[Tuple[str]], List[List[List[str]]], List[float]]:
        """
        Drop candidate entities which are unlikely answers to Russian questions containing ``'что '``.

        If the lowercased question contains ``'что '``, a candidate is removed when its first element is
        uppercase or when it has a triplet stating that it is an instance of (``P31``) a human (``Q5``) or
        an asteroid (``Q3863``). For other questions all candidates are kept.

        Args:
            entity_triplets: triplets for every candidate entity
            confidences: confidences of the candidate entities
            question_tokens: tokens of the question
            srtd_cand_ent: candidate entities

        Returns:
            the kept entities, their triplets and their confidences
        """
        question = ' '.join(question_tokens).lower()
        found_what_template = 'что ' in question

        property_is_instance_of = 'P31'
        id_for_entity_human = 'Q5'
        id_for_entity_asteroid = 'Q3863'
        excluded_classes = (id_for_entity_human, id_for_entity_asteroid)

        filtered_entities = []
        filtered_entity_triplets = []
        filtered_confidences = []

        for wiki_entity, confidence, triplets_for_entity in zip(srtd_cand_ent, confidences, entity_triplets):
            # An empty candidate is treated as not named instead of raising IndexError.
            # NOTE(review): __call__ passes the linker's entity ids here; if those are Wikidata ids such as
            # 'Q42' this check is always true - confirm the intended input.
            entity_is_named = bool(wiki_entity) and wiki_entity[0].isupper()
            entity_is_human_or_asteroid = any(
                triplet[0] == property_is_instance_of and triplet[1] in excluded_classes
                for triplet in triplets_for_entity)

            if found_what_template and (entity_is_named or entity_is_human_or_asteroid):
                continue

            filtered_entities.append(wiki_entity)
            filtered_entity_triplets.append(triplets_for_entity)
            filtered_confidences.append(confidence)

        return filtered_entities, filtered_entity_triplets, filtered_confidences