Source code for deeppavlov.models.kbqa.rel_ranking_infer

# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from logging import getLogger
from typing import Tuple, List, Any, Optional

from scipy.special import softmax

from deeppavlov.core.common.chainer import Chainer
from deeppavlov.core.common.file import load_pickle, read_json
from deeppavlov.core.common.registry import register
from deeppavlov.core.models.component import Component
from deeppavlov.core.models.serializable import Serializable
from deeppavlov.models.kbqa.sentence_answer import sentence_answer
from deeppavlov.models.kbqa.wiki_parser import WikiParser

log = getLogger(__name__)


@register('rel_ranking_infer')
class RelRankerInfer(Component, Serializable):
    """Class for ranking paths in a subgraph"""
    def __init__(self, load_path: str,
                 rel_q2name_filename: str,
                 return_elements: List[str] = None,
                 ranker: Chainer = None,
                 wiki_parser: Optional[WikiParser] = None,
                 batch_size: int = 32,
                 softmax: bool = False,
                 use_api_requester: bool = False,
                 rank: bool = True,
                 nll_rel_ranking: bool = False,
                 nll_path_ranking: bool = False,
                 top_possible_answers: int = -1,
                 top_n: int = 1,
                 pos_class_num: int = 1,
                 rel_thres: float = 0.0,
                 type_rels: List[str] = None, **kwargs):
        """

        Args:
            load_path: path to the folder with wikidata files
            rel_q2name_filename: name of the file which maps relation ids to names
            return_elements: which elements to return in the output
            ranker: component deeppavlov.models.ranking.rel_ranker
            wiki_parser: component deeppavlov.models.wiki_parser
            batch_size: inference batch size
            softmax: whether to process relation scores with the softmax function
            use_api_requester: whether the wiki parser will be used as an external api
            rank: whether to rank relations or simply copy the input
            nll_rel_ranking: whether to use components trained with nll loss for relation ranking
            nll_path_ranking: whether to use components trained with nll loss for relation path ranking
            top_possible_answers: number of answers returned for a question in each list of candidate answers
            top_n: number of lists of candidate answers returned for a question
            pos_class_num: index of the positive class in the output of the relation ranking model
            rel_thres: threshold of relation confidence
            type_rels: list of relations in the knowledge base which connect an entity and its type
            **kwargs:
        """
        super().__init__(save_path=None, load_path=load_path)
        self.rel_q2name_filename = rel_q2name_filename
        self.ranker = ranker
        self.wiki_parser = wiki_parser
        self.batch_size = batch_size
        self.softmax = softmax
        self.return_elements = return_elements or list()
        self.use_api_requester = use_api_requester
        self.rank = rank
        self.nll_rel_ranking = nll_rel_ranking
        self.nll_path_ranking = nll_path_ranking
        self.top_possible_answers = top_possible_answers
        self.top_n = top_n
        self.pos_class_num = pos_class_num
        self.rel_thres = rel_thres
        self.type_rels = type_rels or set()
        self.load()
    def load(self) -> None:
        if self.rel_q2name_filename.endswith("pickle"):
            self.rel_q2name = load_pickle(self.load_path / self.rel_q2name_filename)
        elif self.rel_q2name_filename.endswith("json"):
            self.rel_q2name = read_json(self.load_path / self.rel_q2name_filename)

    def save(self) -> None:
        pass
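For reference, `load` only requires that `rel_q2name_filename` deserialize into a mapping from relation ids to names; the lookup code in `preprocess_ranking_input` and `rank_rels` below accepts either a single string label or a list of labels per relation. A minimal sketch of such a mapping, with illustrative relation ids that are not taken from the shipped wikidata files, could look like this:

# Hedged sketch of the rel_q2name mapping that load() is expected to produce.
# The relation ids and labels are illustrative placeholders.
rel_q2name_example = {
    "P17": "country",                    # a single string label ...
    "P31": ["instance of", "is a"],      # ... or a list of labels (the first one is used)
}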
    def __call__(self, questions_batch: List[str],
                 template_type_batch: List[str],
                 raw_answers_batch: List[List[Tuple[str]]],
                 entity_substr_batch: List[List[str]],
                 template_answers_batch: List[str]) -> List[str]:
        answers_batch, outp_confidences_batch, answer_ids_batch = [], [], []
        entities_and_rels_batch, queries_batch, triplets_batch = [], [], []
        for question, template_type, raw_answers, entities, template_answer in \
                zip(questions_batch, template_type_batch, raw_answers_batch, entity_substr_batch,
                    template_answers_batch):
            answers_with_scores = []
            l_questions, l_rels, l_rels_labels, l_cur_answers, l_entities, l_types, l_sparql_queries, l_triplets, \
                l_confs = self.preprocess_ranking_input(question, raw_answers)
            n_batches = len(l_questions) // self.batch_size + int(len(l_questions) % self.batch_size > 0)
            for i in range(n_batches):
                if self.rank:
                    if self.nll_path_ranking:
                        probas = self.ranker([l_questions[0]],
                                             [l_rels_labels[self.batch_size * i:self.batch_size * (i + 1)]])
                        probas = probas[0]
                    else:
                        probas = self.ranker(l_questions[self.batch_size * i:self.batch_size * (i + 1)],
                                             l_rels_labels[self.batch_size * i:self.batch_size * (i + 1)])
                        probas = [proba[0] for proba in probas]
                else:
                    probas = [rel_conf for rel_conf, entity_conf
                              in l_confs[self.batch_size * i:self.batch_size * (i + 1)]]
                for j in range(self.batch_size * i, self.batch_size * (i + 1)):
                    if j < len(l_cur_answers) and (probas[j - self.batch_size * i] > self.rel_thres
                                                   or (len(l_rels[j]) > 1
                                                       and not set(l_rels[j]).intersection(self.type_rels))):
                        answers_with_scores.append((l_cur_answers[j], l_sparql_queries[j], l_triplets[j],
                                                    l_entities[j], l_types[j], l_rels_labels[j], l_rels[j],
                                                    round(probas[j - self.batch_size * i], 3),
                                                    round(l_confs[j][0], 3), l_confs[j][1]))
            answers_with_scores = sorted(answers_with_scores, key=lambda x: x[-1] * x[-3], reverse=True)
            if template_type == "simple_boolean" and not answers_with_scores:
                answers_with_scores = [(["No"], "", [], [], [], [], [], 1.0, 1.0, 1.0)]
            res_answers_list, res_answer_ids_list, res_confidences_list, res_entities_and_rels_list = [], [], [], []
            res_queries_list, res_triplets_list = [], []
            for n, ans_sc_elem in enumerate(answers_with_scores):
                init_answer_ids, query, triplets, q_entities, q_types, _, q_rels, p_conf, r_conf, e_conf = ans_sc_elem
                answer_ids = []
                for answer_id in init_answer_ids:
                    answer_id = str(answer_id).replace("@en", "").strip('"')
                    if answer_id not in answer_ids:
                        answer_ids.append(answer_id)
                if self.top_possible_answers > 0:
                    answer_ids = answer_ids[:self.top_possible_answers]
                answer_ids_input = [(answer_id, question) for answer_id in answer_ids]
                answer_ids = [str(answer_id).split("/")[-1] for answer_id in answer_ids]
                parser_info_list = ["find_label" for _ in answer_ids_input]
                init_answer_labels = self.wiki_parser(parser_info_list, answer_ids_input)
                if n < 7:
                    log.debug(f"answers: {init_answer_ids[:3]} --- query {query} --- entities {q_entities} --- "
                              f"types {q_types[:3]} --- q_rels {q_rels} --- {ans_sc_elem[5:]} --- "
                              f"answer_labels {init_answer_labels[:3]}")
                answer_labels = []
                for label in init_answer_labels:
                    if label not in answer_labels:
                        answer_labels.append(label)
                answer_labels = [label for label in answer_labels if (label and label != "Not Found")][:5]
                answer_labels = [str(label) for label in answer_labels]
                if len(answer_labels) > 2:
                    answer = f"{', '.join(answer_labels[:-1])} and {answer_labels[-1]}"
                else:
                    answer = ', '.join(answer_labels)
                if "sentence_answer" in self.return_elements:
                    try:
                        answer = sentence_answer(question, answer, entities, template_answer)
                    except ValueError as e:
                        log.warning(f"Error in sentence answer, {e}")
                res_answers_list.append(answer)
                res_answer_ids_list.append(answer_ids)
                if "several_confidences" in self.return_elements:
                    res_confidences_list.append((p_conf, r_conf, e_conf))
                else:
                    res_confidences_list.append(p_conf)
                res_entities_and_rels_list.append([q_entities[:-1], q_rels])
                res_queries_list.append(query)
                res_triplets_list.append(triplets)
            if self.top_n == 1:
                if answers_with_scores:
                    answers_batch.append(res_answers_list[0])
                    outp_confidences_batch.append(res_confidences_list[0])
                    answer_ids_batch.append(res_answer_ids_list[0])
                    entities_and_rels_batch.append(res_entities_and_rels_list[0])
                    queries_batch.append(res_queries_list[0])
                    triplets_batch.append(res_triplets_list[0])
                else:
                    answers_batch.append("Not Found")
                    outp_confidences_batch.append(0.0)
                    answer_ids_batch.append([])
                    entities_and_rels_batch.append([])
                    queries_batch.append([])
                    triplets_batch.append([])
            else:
                answers_batch.append(res_answers_list[:self.top_n])
                outp_confidences_batch.append(res_confidences_list[:self.top_n])
                answer_ids_batch.append(res_answer_ids_list[:self.top_n])
                entities_and_rels_batch.append(res_entities_and_rels_list[:self.top_n])
                queries_batch.append(res_queries_list[:self.top_n])
                triplets_batch.append(res_triplets_list[:self.top_n])

        answer_tuple = (answers_batch,)
        if "confidences" in self.return_elements:
            answer_tuple += (outp_confidences_batch,)
        if "answer_ids" in self.return_elements:
            answer_tuple += (answer_ids_batch,)
        if "entities_and_rels" in self.return_elements:
            answer_tuple += (entities_and_rels_batch,)
        if "queries" in self.return_elements:
            answer_tuple += (queries_batch,)
        if "triplets" in self.return_elements:
            answer_tuple += (triplets_batch,)

        return answer_tuple
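A hedged sketch of how `__call__` consumes its batched inputs follows. The question, entity ids and SPARQL query are invented placeholders (not guaranteed to correspond to real Wikidata items), `rel_ranker_infer` stands for an already-built RelRankerInfer instance, and the candidate relation id must be present in its rel_q2name mapping for the candidate to be kept:

# Illustrative only: data values are made up; the dictionary keys match those
# read in preprocess_ranking_input().
raw_answers = [[{                              # one list of candidates per question
    "relations": ["http://www.wikidata.org/prop/direct/P17"],
    "answers": ["Q155"],
    "entities": ["Q2844"],
    "types": [],
    "sparql_query": "SELECT ?answer WHERE { wd:Q2844 wdt:P17 ?answer }",
    "triplets": [],
    "output_conf": (1.0, 1.0),                 # (relation confidence, entity confidence)
}]]
# With an empty return_elements the call returns a one-element tuple holding the answers batch.
answer_tuple = rel_ranker_infer(["What country is Brasilia the capital of?"],
                                ["simple"], raw_answers, [["Brasilia"]], [""])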
    def preprocess_ranking_input(self, question, answers):
        l_questions, l_rels, l_rels_labels, l_cur_answers = [], [], [], []
        l_entities, l_types, l_sparql_queries, l_triplets, l_confs = [], [], [], [], []
        for ans_and_rels in answers:
            answer, sparql_query, confidence = "", "", []
            entities, types, rels, rels_labels, triplets = [], [], [], [], []
            if ans_and_rels:
                rels = [rel.split('/')[-1] for rel in ans_and_rels["relations"]]
                answer = ans_and_rels["answers"]
                entities = ans_and_rels["entities"]
                types = ans_and_rels["types"]
                sparql_query = ans_and_rels["sparql_query"]
                triplets = ans_and_rels["triplets"]
                confidence = ans_and_rels["output_conf"]
            rels_labels = []
            for rel in rels:
                if rel in self.rel_q2name:
                    label = self.rel_q2name[rel]
                    if isinstance(label, list):
                        label = label[0]
                    rels_labels.append(label.lower())
            if rels_labels:
                l_questions.append(question)
                l_rels.append(rels)
                l_rels_labels.append(rels_labels)
                l_cur_answers.append(answer)
                l_entities.append(entities)
                l_types.append(types)
                l_sparql_queries.append(sparql_query)
                l_triplets.append(triplets)
                l_confs.append(confidence)
        return l_questions, l_rels, l_rels_labels, l_cur_answers, l_entities, l_types, l_sparql_queries, \
            l_triplets, l_confs

    def rank_rels(self, question: str, candidate_rels: List[str]) -> List[Tuple[str, Any]]:
        rels_with_scores = []
        if question is not None:
            questions, rels_labels, rels = [], [], []
            for candidate_rel in candidate_rels:
                if candidate_rel in self.rel_q2name:
                    cur_rels_labels = self.rel_q2name[candidate_rel]
                    if isinstance(cur_rels_labels, str):
                        cur_rels_labels = [cur_rels_labels]
                    for cur_rel in cur_rels_labels:
                        questions.append(question)
                        rels.append(candidate_rel)
                        rels_labels.append(cur_rel)
            if questions:
                n_batches = len(rels) // self.batch_size + int(len(rels) % self.batch_size > 0)
                for i in range(n_batches):
                    if self.nll_rel_ranking:
                        probas = self.ranker([questions[0]],
                                             [rels_labels[i * self.batch_size:(i + 1) * self.batch_size]])
                        probas = probas[0]
                    else:
                        probas = self.ranker(questions[i * self.batch_size:(i + 1) * self.batch_size],
                                             rels_labels[i * self.batch_size:(i + 1) * self.batch_size])
                        probas = [proba[self.pos_class_num] for proba in probas]
                    for j, rel in enumerate(rels[i * self.batch_size:(i + 1) * self.batch_size]):
                        rels_with_scores.append((rel, probas[j]))
                if self.softmax:
                    scores = [score for rel, score in rels_with_scores]
                    softmax_scores = softmax(scores)
                    rels_with_scores = [(rel, softmax_score) for (rel, score), softmax_score
                                        in zip(rels_with_scores, softmax_scores)]
                rels_with_scores_dict = {}
                for rel, score in rels_with_scores:
                    if rel not in rels_with_scores_dict:
                        rels_with_scores_dict[rel] = []
                    rels_with_scores_dict[rel].append(score)
                rels_with_scores = [(rel, max(scores)) for rel, scores in rels_with_scores_dict.items()]
                rels_with_scores = sorted(rels_with_scores, key=lambda x: x[1], reverse=True)

        return rels_with_scores
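Finally, a hedged sketch of `rank_rels`, which scores candidate relation ids against a question and returns them best-first. The instance name and relation ids below are placeholders, and only relations present in the rel_q2name mapping are scored:

# Hypothetical usage: `rel_ranker_infer` is assumed to be an already-initialised
# RelRankerInfer; the relation ids are illustrative.
scored = rel_ranker_infer.rank_rels("Who directed Forrest Gump?", ["P57", "P161", "P31"])
for rel_id, score in scored:
    # relations sorted by score in descending order
    print(rel_id, round(float(score), 3))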