# Source code for deeppavlov.models.doc_retrieval.logit_ranker

# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Union, Tuple
from operator import itemgetter
import warnings

from deeppavlov.core.common.registry import register
from deeppavlov.core.common.log import get_logger
from deeppavlov.core.models.estimator import Component
from deeppavlov.core.common.chainer import Chainer

logger = get_logger(__name__)


@register("logit_ranker")
class LogitRanker(Component):
    """Select the best answer using squad model logits.

    Each inner list of a batch is split into sub-batches of ``batch_size``,
    every sub-batch is sent to the squad model separately, and a single best
    answer is chosen for the whole inner list.

    Args:
        squad_model: a loaded squad model
        batch_size: batch size to use with squad model
        sort_noans: if True, empty ('no answer') predictions are ranked below
            every non-empty answer regardless of their logits
        **kwargs: accepted and ignored (presumably extra config keys)

    Attributes:
        squad_model: a loaded squad model
        batch_size: batch size to use with squad model
        sort_noans: whether empty ('no answer') predictions are downgraded
    """

    def __init__(self, squad_model: Union[Chainer, Component], batch_size: int = 50,
                 sort_noans: bool = False, **kwargs):
        self.squad_model = squad_model
        self.batch_size = batch_size
        self.sort_noans = sort_noans

    def __call__(self, contexts_batch: List[List[str]], questions_batch: List[List[str]]) -> \
            Tuple[List[str], List[float]]:
        """Rank squad reader predictions by logit and keep the top one per sample.

        Args:
            contexts_batch: a batch of contexts which should be treated as a single batch in the outer JSON config
            questions_batch: a batch of questions which should be treated as a single batch in the outer JSON config

        Returns:
            a batch of best answers and their scores (the logit of the chosen prediction)

        Raises:
            IndexError: if a sample yields no predictions at all (e.g. its list of contexts is empty)
        """
        # TODO output result for top_n
        warnings.warn(f'{self.__class__.__name__}.__call__() API will be changed in the future release.'
                      ' Instead of returning Tuple(List[str], List[float]) will return'
                      ' Tuple(List[List[str]], List[List[float]]).', FutureWarning)

        # Only fields 0 (answer text) and 2 (logit) of each prediction tuple are used;
        # NOTE(review): field 1 is presumably the answer start position — confirm.
        def _answer_before_noans(prediction):
            # Non-empty answers (True) outrank empty ones (False); ties are broken by logit.
            return prediction[0] != '', prediction[2]

        # The ranking key is the same for every sample, so choose it once.
        rank_key = _answer_before_noans if self.sort_noans else itemgetter(2)

        batch_best_answers = []
        batch_best_answers_scores = []
        for contexts, questions in zip(contexts_batch, questions_batch):
            results = []
            # questions are sliced with the contexts' indices, so both lists are
            # assumed to have the same length — TODO confirm with callers.
            for i in range(0, len(contexts), self.batch_size):
                c_batch = contexts[i: i + self.batch_size]
                q_batch = questions[i: i + self.batch_size]
                # The model returns parallel sequences; transpose them into per-sample tuples.
                results += zip(*self.squad_model(c_batch, q_batch))

            if not results:
                raise IndexError(f'{self.__class__.__name__}: no squad_model predictions to rank '
                                 '(is the list of contexts empty?)')

            # max() returns the first maximal item, matching the first element of a
            # stable sort with reverse=True, without sorting the whole list.
            best = max(results, key=rank_key)
            batch_best_answers.append(best[0])
            batch_best_answers_scores.append(best[2])
        return batch_best_answers, batch_best_answers_scores