from logging import getLogger
from typing import List, Optional, Dict, Tuple, Union

import numpy as np
import torch
from torch import nn, Tensor

from deeppavlov.core.common.errors import ConfigError
from deeppavlov.core.common.registry import register
from deeppavlov.core.models.torch_model import TorchModel
from deeppavlov.models.classifiers.re_bert import BertWithAdaThresholdLocContextPooling

log = getLogger(__name__)

@register('re_classifier')
class REBertModel(TorchModel):
    """Transformer-based PyTorch model for relation extraction.

    The network is built by :meth:`re_model` (presumably selected by passing its
    name as ``model_name`` to ``TorchModel`` — confirm against the config).
    """

    def __init__(
            self,
            n_classes: int,
            model_name: str,
            num_ner_tags: int,
            pretrained_bert: Optional[str] = None,
            criterion: str = "CrossEntropyLoss",
            optimizer: str = "AdamW",
            optimizer_parameters: Optional[Dict] = None,
            return_probas: bool = False,
            attention_probs_keep_prob: Optional[float] = None,
            hidden_keep_prob: Optional[float] = None,
            clip_norm: Optional[float] = None,
            threshold: Optional[float] = None,
            device: str = "cpu",
            **kwargs
    ) -> None:
        """Transformer-based model on PyTorch for relation extraction.

        It predicts the relations holding between entities in a text sample
        (one or several sentences).

        Args:
            n_classes: number of output classes (must be a positive integer)
            model_name: the model which will be used for extracting the relations
            num_ner_tags: number of NER tags
            pretrained_bert: key title of pretrained Bert model (e.g. "bert-base-uncased")
            criterion: criterion name from `torch.nn`
            optimizer: optimizer name from `torch.optim`
            optimizer_parameters: dictionary with optimizer's parameters; defaults to
                ``{"lr": 5e-5, "weight_decay": 0.01, "eps": 1e-6}``
            return_probas: if `True`, `__call__` returns a one-hot array of the best
                non-"no relation" class per sample instead of the raw class indices
            attention_probs_keep_prob: keep_prob for Bert self-attention layers
                (stored on the instance; not forwarded by `re_model` in this class)
            hidden_keep_prob: keep_prob for Bert hidden layers
                (stored on the instance; not forwarded by `re_model` in this class)
            clip_norm: if set, gradients are clipped to this total norm before each
                optimizer step
            threshold: manually set value for defining the positively predicted
                classes (instead of the adaptive one)
            device: cpu/gpu device to use for training the model

        Raises:
            ConfigError: if `n_classes` is missing or not positive.
        """
        # Reject None / 0 / negative class counts up front, before any model is built.
        if n_classes is None or n_classes < 1:
            raise ConfigError("Please provide a valid number of classes.")

        self.n_classes = n_classes
        self.num_ner_tags = num_ner_tags
        self.pretrained_bert = pretrained_bert
        self.return_probas = return_probas
        self.attention_probs_keep_prob = attention_probs_keep_prob
        self.hidden_keep_prob = hidden_keep_prob
        self.clip_norm = clip_norm
        self.threshold = threshold
        self.device = device

        if optimizer_parameters is None:
            optimizer_parameters = {"lr": 5e-5, "weight_decay": 0.01, "eps": 1e-6}

        super().__init__(
            n_classes=n_classes,
            model_name=model_name,
            optimizer=optimizer,
            criterion=criterion,
            optimizer_parameters=optimizer_parameters,
            return_probas=return_probas,
            device=self.device,
            **kwargs)

    def train_on_batch(
            self, input_ids: List, attention_mask: List, entity_pos: List, entity_tags: List, labels: List
    ) -> float:
        """Run one optimisation step of the relation extraction BERT model on a batch.

        Args:
            input_ids: token ids, converted to a ``LongTensor`` on `self.device`
            attention_mask: attention mask, converted to a ``LongTensor`` on `self.device`
            entity_pos: entity positions, passed to the model unchanged
            entity_tags: entity NER tags, passed to the model as ``ner_tags``
            labels: gold labels, passed to the model unchanged

        Returns:
            The scalar training loss of the batch.
        """
        model_input = {
            'input_ids': torch.LongTensor(input_ids).to(self.device),
            'attention_mask': torch.LongTensor(attention_mask).to(self.device),
            'entity_pos': entity_pos,
            'ner_tags': entity_tags,
            'labels': labels,
        }

        self.model.train()
        self.model.zero_grad()
        self.optimizer.zero_grad()  # zero the parameter gradients

        outputs = self.model(**model_input)
        loss = outputs[0]  # with labels supplied, the first model output is the loss
        loss.backward()

        # Clip the norm of the gradients to prevent the "exploding gradients" problem.
        # This must happen BEFORE optimizer.step(); clipping afterwards has no effect
        # on the update that was just applied.
        if self.clip_norm:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_norm)

        self.optimizer.step()
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return loss.item()

    def __call__(
            self, input_ids: List, attention_mask: List, entity_pos: List, entity_tags: List
    ) -> np.ndarray:
        """Get model predictions using features as input.

        Returns:
            If `return_probas` is `False`: the model's predicted indices as a numpy
            array (NaNs replaced with 0).
            If `return_probas` is `True`: a float array of shape
            ``(batch, n_classes)`` with a single 1.0 per row at the highest-scoring
            class other than class 0 ("no relation").
        """
        self.model.eval()

        model_input = {
            'input_ids': torch.LongTensor(input_ids).to(self.device),
            'attention_mask': torch.LongTensor(attention_mask).to(self.device),
            'entity_pos': entity_pos,
            'ner_tags': entity_tags,
        }

        with torch.no_grad():
            indices, probas = self.model(**model_input)

        if self.return_probas:
            pred = probas.cpu().numpy()
            pred[np.isnan(pred)] = 0
            return self._one_hot_best_relation(pred, self.n_classes)

        pred = indices.cpu().numpy()
        pred[np.isnan(pred)] = 0
        return pred

    @staticmethod
    def _one_hot_best_relation(probas: np.ndarray, n_classes: int) -> np.ndarray:
        """One-hot encode, per row of 2-D `probas`, the best class excluding class 0.

        Class 0's score is set to 0.0 (not -inf), so it can still win a row whose
        other scores are all <= 0. Each row gets its own independent one-hot vector.
        An empty batch yields an empty ``(0, n_classes)`` array.
        """
        scores = np.array(probas, dtype=float)  # copy; do not mutate the caller's array
        scores[:, 0] = 0.0  # eliminate "no relation" predictions
        best = scores.argmax(axis=1)

        # A fresh zero matrix, so rows do not alias each other.
        one_hot = np.zeros((len(best), n_classes))
        one_hot[np.arange(len(best)), best] = 1.0
        return one_hot

    def re_model(self, **kwargs) -> nn.Module:
        """Build the relation-extraction network.

        BERT tokenizer -> input features -> BERT -> hidden states -> mean of entity
        representations; bilinear formula -> the whole model.
        model <= BERT + additional processing

        `kwargs` are accepted but not used here.
        """
        return BertWithAdaThresholdLocContextPooling(
            n_classes=self.n_classes,
            pretrained_bert=self.pretrained_bert,
            # NOTE(review): the tokenizer config is taken from `pretrained_bert` as well.
            bert_tokenizer_config_file=self.pretrained_bert,
            num_ner_tags=self.num_ner_tags,
            threshold=self.threshold,
            device=self.device,
        )

    def collate_fn(self, batch: List[Dict]) -> Tuple[Tensor, Tensor, List, List, List]:
        """Collate a list of feature dicts into model-ready batch components.

        Each dict must have the keys "input_ids", "attention_mask", "entity_pos",
        "ner_tags" and "label".

        Returns:
            ``(input_ids, attention_mask, entity_pos, ner_tags, label)``.
            `input_ids` is a long tensor. `attention_mask` is a float tensor; the
            remaining items are plain lists.
        """
        input_ids = torch.tensor([f["input_ids"] for f in batch], dtype=torch.long)
        label = [f["label"] for f in batch]
        entity_pos = [f["entity_pos"] for f in batch]
        ner_tags = [f["ner_tags"] for f in batch]
        # NOTE(review): float mask here, whereas train_on_batch/__call__ build a
        # LongTensor mask — confirm which dtype the model expects.
        attention_mask = torch.tensor([f["attention_mask"] for f in batch], dtype=torch.float)
        return input_ids, attention_mask, entity_pos, ner_tags, label