# Source code for deeppavlov.models.kbqa.rel_ranking_infer

# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from logging import getLogger
from typing import Tuple, List, Any, Optional

from scipy.special import softmax

from deeppavlov.core.common.chainer import Chainer
from deeppavlov.core.common.file import load_pickle
from deeppavlov.core.common.registry import register
from deeppavlov.core.models.component import Component
from deeppavlov.core.models.serializable import Serializable
from deeppavlov.models.kbqa.sentence_answer import sentence_answer
from deeppavlov.models.kbqa.wiki_parser import WikiParser

log = getLogger(__name__)


@register('rel_ranking_infer')
class RelRankerInfer(Component, Serializable):
    """Class for ranking of paths in subgraph"""

    def __init__(self, load_path: str,
                 rel_q2name_filename: str,
                 ranker: Optional[Chainer] = None,
                 wiki_parser: Optional[WikiParser] = None,
                 batch_size: int = 32,
                 rels_to_leave: int = 40,
                 softmax: bool = False,
                 return_all_possible_answers: bool = False,
                 return_answer_ids: bool = False,
                 use_api_requester: bool = False,
                 return_sentence_answer: bool = False,
                 rank: bool = True,
                 return_confidences: bool = False, **kwargs):
        """

        Args:
            load_path: path to folder with wikidata files
            rel_q2name_filename: name of file which maps relation id to name
            ranker: component deeppavlov.models.ranking.rel_ranker
            wiki_parser: component deeppavlov.models.wiki_parser
            batch_size: inference batch size
            rels_to_leave: how many relations to leave after relation ranking
            softmax: whether to process relation scores with softmax function
            return_all_possible_answers: whether to return all found answers
            return_answer_ids: whether to return answer ids from Wikidata
            use_api_requester: whether wiki parser will be used as external api
            return_sentence_answer: whether to return answer as a sentence
            rank: whether to rank relations or simply copy input
            return_confidences: whether to return confidences of candidate answers
            **kwargs: extra configuration arguments; not used by this class
        """
        super().__init__(save_path=None, load_path=load_path)
        self.rel_q2name_filename = rel_q2name_filename
        self.ranker = ranker
        self.wiki_parser = wiki_parser
        self.batch_size = batch_size
        self.rels_to_leave = rels_to_leave
        self.softmax = softmax
        self.return_all_possible_answers = return_all_possible_answers
        self.return_answer_ids = return_answer_ids
        self.use_api_requester = use_api_requester
        self.return_sentence_answer = return_sentence_answer
        self.rank = rank
        self.return_confidences = return_confidences
        self.load()

    def load(self) -> None:
        """Load the relation-id -> relation-name mapping from ``load_path / rel_q2name_filename``."""
        # NOTE(review): load_pickle unpickles the file; only point load_path at trusted data.
        self.rel_q2name = load_pickle(self.load_path / self.rel_q2name_filename)

    def save(self) -> None:
        """Do nothing: this component has no state to persist."""
        pass

    def __call__(self, questions_list: List[str],
                 candidate_answers_list: List[List[Any]],
                 entities_list: Optional[List[List[str]]] = None,
                 template_answers_list: Optional[List[str]] = None) -> List[Any]:
        """Pick the best candidate answer for every question and turn it into a label.

        Args:
            questions_list: questions, one per sample
            candidate_answers_list: candidate answers for every question. When ``rank`` is
                True each candidate is indexed with the keys ``"relations"``, ``"answers"``,
                ``"entities"`` and ``"rel_conf"`` (empty candidates are skipped). When
                ``rank`` is False each candidate is a sequence whose last two items are
                the answer and its confidence.
            entities_list: entities of every question, forwarded to ``sentence_answer``;
                defaults to an empty list per question
            template_answers_list: answer templates, forwarded to ``sentence_answer``;
                defaults to an empty string per question

        Returns:
            One item per question: ``(answer, confidence)`` if ``return_confidences`` is
            set, else ``(answer, answer_ids)`` if ``return_answer_ids`` is set, else the
            answer string. A question without usable candidates yields ``"Not Found"``.
            If no question was given at all, a single ``"Not found"`` placeholder
            (paired with ``0.0`` when ``return_confidences`` is set) is returned.
        """
        answers: List[Any] = []
        if entities_list is None:
            entities_list = [[] for _ in questions_list]
        if template_answers_list is None:
            template_answers_list = ["" for _ in questions_list]

        for question, candidate_answers, entities, template_answer in \
                zip(questions_list, candidate_answers_list, entities_list, template_answers_list):
            answer = "Not Found"
            answer_ids = tuple()
            # Reset for every question so that a question without answers does not
            # report the confidence computed for a previous question.
            confidence = 0.0

            if self.rank:
                # Ranking lives in a helper with its own local names, so the
                # question-level `entities` handed to sentence_answer below is never
                # overwritten by the entities of an individual candidate.
                answers_with_scores = self._rank_candidates(question, candidate_answers)
            else:
                answers_with_scores = [(cand_answer, rels, conf)
                                       for *rels, cand_answer, conf in candidate_answers]

            if answers_with_scores:
                best = answers_with_scores[0]
                log.debug(f"answers: {best}")
                answer_ids = best[0]
                if self.return_all_possible_answers and isinstance(answer_ids, tuple):
                    answer_ids_input = [(answer_id, question) for answer_id in answer_ids]
                    answer_ids = [x.split("/")[-1] if str(x).startswith("http") else x
                                  for x in answer_ids]
                else:
                    answer_ids_input = [(answer_ids, question)]
                    if str(answer_ids).startswith("http:"):
                        answer_ids = answer_ids.split("/")[-1]

                parser_info_list = ["find_label" for _ in answer_ids_input]
                answer_labels = self.wiki_parser(parser_info_list, answer_ids_input)
                log.debug(f"answer_labels {answer_labels}")

                if self.return_all_possible_answers:
                    # dict.fromkeys de-duplicates while keeping the parser's order, so
                    # the produced answer (and which five labels survive) is deterministic.
                    unique_labels = dict.fromkeys(answer_labels)
                    answer_labels = [str(label) for label in unique_labels
                                     if (label and label != "Not Found")][:5]
                    if len(answer_labels) > 2:
                        answer = f"{', '.join(answer_labels[:-1])} and {answer_labels[-1]}"
                    else:
                        answer = ', '.join(answer_labels)
                elif answer_labels:
                    # If the parser returned nothing, keep the "Not Found" default
                    # instead of failing with IndexError.
                    answer = answer_labels[0]

                if self.return_sentence_answer:
                    try:
                        answer = sentence_answer(question, answer, entities, template_answer)
                    except Exception:
                        # Best effort: fall back to the plain label, but do not swallow
                        # KeyboardInterrupt / SystemExit as a bare `except` would.
                        log.warning("Error in sentence answer")

                # The score is the last element in both the ranked 5-tuples and the
                # unranked 3-tuples (index 2 is the relation labels in the ranked case).
                confidence = best[-1]

            if self.return_confidences:
                answers.append((answer, confidence))
            elif self.return_answer_ids:
                if not answer_ids:
                    answer_ids = "Not found"
                answers.append((answer, answer_ids))
            else:
                answers.append(answer)

        if not answers:
            if self.return_confidences:
                answers.append(("Not found", 0.0))
            else:
                answers.append("Not found")

        return answers

    def rank_rels(self, question: str, candidate_rels: List[str]) -> List[Tuple[str, Any]]:
        """Score candidate relations against the question and keep the best ones.

        Args:
            question: question text; if None, nothing is ranked
            candidate_rels: relation ids; ids missing from ``rel_q2name`` are ignored

        Returns:
            At most ``rels_to_leave`` ``(relation id, score)`` pairs sorted by descending
            score. Scores are passed through softmax first when ``softmax`` is set.
        """
        rels_with_scores = []
        if question is not None:
            for start in range(0, len(candidate_rels), self.batch_size):
                rels_batch = [rel for rel in candidate_rels[start:start + self.batch_size]
                              if rel in self.rel_q2name]
                if rels_batch:
                    rels_labels_batch = [self.rel_q2name[rel] for rel in rels_batch]
                    scores = self._score_labels(question, rels_labels_batch)
                    rels_with_scores.extend(zip(rels_batch, scores))

        # softmax cannot reduce over an empty score list, so skip it when nothing was scored.
        if self.softmax and rels_with_scores:
            softmax_scores = softmax([score for _, score in rels_with_scores])
            rels_with_scores = [(rel, softmax_score)
                                for (rel, _), softmax_score in zip(rels_with_scores, softmax_scores)]

        rels_with_scores = sorted(rels_with_scores, key=lambda x: x[1], reverse=True)
        return rels_with_scores[:self.rels_to_leave]

    def _score_labels(self, question: str, rels_labels: List[str]) -> List[Any]:
        """Run the ranker on (question, relation label) pairs and return item ``[1]`` of each output."""
        # NOTE(review): item [1] is presumably the "relevant" class probability - confirm
        # against the ranker component.
        probas = self.ranker([question] * len(rels_labels), rels_labels)
        return [proba[1] for proba in probas]

    def _rank_candidates(self, question: str, candidate_answers: List[Any]) -> List[Tuple[Any, ...]]:
        """Score candidates in batches; return ``(answer, entities, rels_label, rel_ids, score)`` best first."""
        scored = []
        for start in range(0, len(candidate_answers), self.batch_size):
            rels_batch, rels_labels_batch = [], []
            answers_batch, entities_batch, confidences_batch = [], [], []
            for candidate in candidate_answers[start:start + self.batch_size]:
                if not candidate:
                    continue
                rel_ids = [rel.split('/')[-1] for rel in candidate["relations"]]
                cand_answer = candidate["answers"]
                cand_entities = candidate["entities"]
                cand_confidence = candidate["rel_conf"]
                rels_label = " # ".join(self.rel_q2name[rel] for rel in rel_ids
                                        if rel in self.rel_q2name)
                # A candidate none of whose relations has a known name cannot be ranked.
                if not rels_label:
                    continue
                rels_batch.append(rel_ids)
                rels_labels_batch.append(rels_label)
                answers_batch.append(cand_answer)
                entities_batch.append(cand_entities)
                confidences_batch.append(cand_confidence)

            if rels_labels_batch:
                scores = self._score_labels(question, rels_labels_batch)
                for score, cand_answer, cand_entities, cand_confidence, rel_ids, rels_label in \
                        zip(scores, answers_batch, entities_batch, confidences_batch,
                            rels_batch, rels_labels_batch):
                    # The final score is the larger of the ranker score and the
                    # candidate's own "rel_conf" value.
                    scored.append((cand_answer, cand_entities, rels_label, rel_ids,
                                   max(score, cand_confidence)))

        return sorted(scored, key=lambda x: x[-1], reverse=True)