Source code for deeppavlov.models.classifiers.torch_classification_model

# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import List, Union, Optional

import numpy as np
import torch

from deeppavlov.core.common.errors import ConfigError
from deeppavlov.core.common.registry import register
from deeppavlov.core.models.torch_model import TorchModel
from .torch_nets import ShallowAndWideCnn

log = logging.getLogger(__name__)


@register('torch_text_classification_model')
class TorchTextClassificationModel(TorchModel):
    """Class implements torch model for classification of texts.

    Input can either be embedded tokenized texts OR indices of words in the vocabulary.
    Number of tokens is not fixed while the samples in batch should be padded to the same
    (e.g. longest) lengths.

    Args:
        n_classes: number of classes
        kernel_sizes_cnn: list of kernel sizes of convolutions
        filters_cnn: number of filters for convolutions
        dense_size: number of units for dense layer
        dropout_rate: dropout rate, after convolutions and between dense
        embedding_size: size of vector representation of words
        multilabel: is multi-label classification (if so, `sigmoid` activation will be used,
            otherwise, softmax)
        criterion: criterion name from `torch.nn`
        embedded_tokens: True, if input contains embedded tokenized texts;
            False, if input contains indices of words in the vocabulary
        vocab_size: vocabulary size in case of `embedded_tokens=False`, and embedding is a layer
            in the Network
        return_probas: whether to return probabilities or index of classes
            (only for `multilabel=False`)
        **kwargs: additional arguments passed through to the parent ``TorchModel`` constructor

    Attributes:
        model: torch model itself
        epochs_done: number of epochs that were done
        criterion: torch criterion instance
        multilabel: whether the model is used for multi-label classification
        return_probas: whether ``__call__`` returns probabilities instead of class indices

    Raises:
        ConfigError: if ``n_classes`` is 0
        RuntimeError: if ``multilabel`` is set while ``return_probas`` is False
    """

    def __init__(self, n_classes: int,
                 kernel_sizes_cnn: List[int],
                 filters_cnn: int,
                 dense_size: int,
                 dropout_rate: float = 0.0,
                 embedding_size: Optional[int] = None,
                 multilabel: bool = False,
                 criterion: str = "CrossEntropyLoss",
                 embedded_tokens: bool = True,
                 vocab_size: Optional[int] = None,
                 return_probas: bool = True,
                 **kwargs):

        if n_classes == 0:
            raise ConfigError("Please, provide vocabulary with considered classes or number of classes.")

        # argmax over per-class sigmoid scores would silently drop all but one label
        if multilabel and not return_probas:
            raise RuntimeError('Set return_probas to True for multilabel classification!')

        self.multilabel = multilabel
        self.return_probas = return_probas

        model = ShallowAndWideCnn(
            n_classes=n_classes,
            embedding_size=embedding_size,
            kernel_sizes_cnn=kernel_sizes_cnn,
            filters_cnn=filters_cnn,
            dense_size=dense_size,
            dropout_rate=dropout_rate,
            embedded_tokens=embedded_tokens,
            vocab_size=vocab_size,
        )
        # the criterion is looked up by name in `torch.nn` and built with its default arguments
        self.criterion = getattr(torch.nn, criterion)()
        super().__init__(model, **kwargs)

    def __call__(self, texts: List[np.ndarray], *args) -> Union[List[List[float]], List[int]]:
        """Infer on the given data.

        Args:
            texts: list of tokenized text samples
            *args: additional arguments (ignored)

        Returns:
            for each sentence:
                vector of probabilities to belong with each class
                or list of labels sentence belongs with
        """
        # NOTE(review): the model is not switched to eval mode here; confirm that the caller or the
        # parent class does so, otherwise dropout stays active during inference.
        with torch.no_grad():
            features = np.array(texts)
            inputs = torch.from_numpy(features).to(self.device)
            outputs = self.model(inputs)
            if self.multilabel:
                # `torch.sigmoid` is the supported spelling of the legacy `torch.nn.functional.sigmoid`
                outputs = torch.sigmoid(outputs)
            else:
                outputs = torch.nn.functional.softmax(outputs, dim=-1)

        outputs = outputs.cpu().detach().numpy()
        if self.return_probas:
            return outputs.tolist()
        return np.argmax(outputs, axis=-1).tolist()

    def train_on_batch(self, texts: List[List[np.ndarray]], labels: list) -> Union[float, List[float]]:
        """Train the model on the given batch.

        Args:
            texts: vectorized texts
            labels: list of labels; class indices for single-label classification, or
                (in multilabel mode) per-class targets shaped like the model output

        Returns:
            metrics values on the given batch
        """
        # NOTE(review): the model is not switched to train mode here; confirm that the caller or the
        # parent class does so.
        features, labels = np.array(texts), np.array(labels)

        inputs = torch.from_numpy(features).to(self.device)
        labels = torch.from_numpy(labels).to(self.device)

        # zero the parameter gradients
        self.optimizer.zero_grad()

        # forward + backward + optimize
        outputs = self.model(inputs)
        if self.multilabel and labels.shape == outputs.shape:
            # Multi-hot targets must keep the (batch, n_classes) layout and share the logits' dtype
            # for element-wise criteria; flattening them would break the shape match.
            labels = labels.to(outputs.dtype)
        else:
            # class-index targets: one integer label per sample
            labels = labels.view(-1).long()
        loss = self.criterion(outputs, labels)
        self._make_step(loss)
        return loss.item()