# Source code for deeppavlov.models.kbqa.rel_ranking_infer

# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple, List, Any

from scipy.special import softmax

from deeppavlov.core.common.registry import register
from deeppavlov.core.models.component import Component
from deeppavlov.core.models.serializable import Serializable
from deeppavlov.core.common.file import load_pickle
from deeppavlov.models.ranking.rel_ranker import RelRanker


@register('rel_ranking_infer')
class RelRankerInfer(Component, Serializable):
    """Scores candidate relations against a question and keeps the best-scored ones."""

    def __init__(self, load_path: str, rel_q2name_filename: str, ranker: RelRanker,
                 rels_to_leave: int = 15, batch_size: int = 100, **kwargs):
        """
        Args:
            load_path: path to the folder holding the relation-name file
            rel_q2name_filename: name of the file mapping relation ids to relation names
            ranker: a deeppavlov.models.ranking.rel_ranker instance used to score
                (question, relation name) pairs
            rels_to_leave: how many top-scored relations to return per question
            batch_size: size of the slices of candidate relations sent to the ranker
            **kwargs: accepted but not used
        """
        super().__init__(save_path=None, load_path=load_path)
        self.ranker = ranker
        self.rels_to_leave = rels_to_leave
        self.batch_size = batch_size
        self.rel_q2name_filename = rel_q2name_filename
        # load() reads self.load_path (populated by the base-class initializer above)
        # together with the filename, so it must run after both are in place.
        self.load()

    def load(self) -> None:
        """Populate ``self.rel_q2name`` (relation id -> relation name) from disk."""
        # NOTE(review): load_pickle presumably unpickles the file, which can execute
        # arbitrary code - only point load_path at trusted data.
        rel_names_path = self.load_path / self.rel_q2name_filename
        self.rel_q2name = load_pickle(rel_names_path)

    def save(self) -> None:
        """No-op: this component writes nothing on save."""

    def __call__(self, question_batch: List[str],
                 candidate_rels_batch: List[List[str]]) -> List[List[Tuple[str, Any]]]:
        """Rank candidate relations for every question of a batch.

        Args:
            question_batch: questions to rank relations for
            candidate_rels_batch: for each question, its list of candidate relation ids

        Returns:
            For each question, the result of :meth:`rank_rels`. Questions and candidate
            lists are paired with ``zip``, so surplus items of the longer batch are ignored.
        """
        return [self.rank_rels(question, candidate_rels)
                for question, candidate_rels in zip(question_batch, candidate_rels_batch)]

    def rank_rels(self, question: str, candidate_rels: List[str]) -> List[Tuple[str, Any]]:
        """Score the candidate relations of a single question.

        Candidates whose id is absent from ``self.rel_q2name`` are skipped. The
        (unfiltered) candidate list is walked in slices of ``self.batch_size``. For each
        slice the ranker receives the question repeated once per kept relation,
        together with those relations' names. Element ``[1]`` of each ranker output is
        taken as that relation's raw score.

        Args:
            question: the question text
            candidate_rels: candidate relation ids

        Returns:
            Up to ``self.rels_to_leave`` ``(relation id, score)`` pairs, sorted from
            highest to lowest score. The scores are a softmax over the raw scores of all
            kept candidates. The list is empty when no candidate is known.
        """
        # Ceiling division; a zero batch_size raises ZeroDivisionError here.
        n_full, remainder = divmod(len(candidate_rels), self.batch_size)
        n_slices = n_full + (1 if remainder > 0 else 0)

        scored = []
        for slice_idx in range(n_slices):
            lo = slice_idx * self.batch_size
            window = candidate_rels[lo:lo + self.batch_size]
            known = [rel for rel in window if rel in self.rel_q2name]
            if not known:
                continue
            names = [self.rel_q2name[rel] for rel in known]
            outputs = self.ranker([question] * len(known), names)
            raw_scores = [output[1] for output in outputs]
            # Index (rather than zip) so a short ranker output fails loudly.
            scored.extend((rel, raw_scores[pos]) for pos, rel in enumerate(known))

        if scored:
            # NOTE(review): if the ranker already returns class probabilities, this
            # softmax compresses the differences between them - confirm it is intended.
            normalized = softmax([raw for _, raw in scored])
            scored = [(rel, value) for (rel, _), value in zip(scored, normalized)]

        scored.sort(key=lambda pair: pair[1], reverse=True)
        return scored[:self.rels_to_leave]