deeppavlov.models.relation_extraction

class deeppavlov.models.relation_extraction.relation_extraction_bert.REBertModel(n_classes: int, model_name: str, num_ner_tags: int, pretrained_bert: Optional[str] = None, criterion: str = 'CrossEntropyLoss', optimizer: str = 'AdamW', optimizer_parameters: Optional[Dict] = None, return_probas: bool = False, attention_probs_keep_prob: Optional[float] = None, hidden_keep_prob: Optional[float] = None, clip_norm: Optional[float] = None, threshold: Optional[float] = None, device: str = 'cpu', **kwargs)[source]
__init__(n_classes: int, model_name: str, num_ner_tags: int, pretrained_bert: Optional[str] = None, criterion: str = 'CrossEntropyLoss', optimizer: str = 'AdamW', optimizer_parameters: Optional[Dict] = None, return_probas: bool = False, attention_probs_keep_prob: Optional[float] = None, hidden_keep_prob: Optional[float] = None, clip_norm: Optional[float] = None, threshold: Optional[float] = None, device: str = 'cpu', **kwargs) → None[source]

Transformer-based PyTorch model for relation extraction. It predicts which relation, if any, holds between entities in a text sample (one or several sentences).

Parameters:
- n_classes – number of output classes
- model_name – the model which will be used for extracting the relations
- num_ner_tags – number of NER tags
- pretrained_bert – key title of pretrained Bert model (e.g. “bert-base-uncased”)
- criterion – criterion name from torch.nn
- optimizer – optimizer name from torch.optim
- optimizer_parameters – dictionary with optimizer’s parameters
- return_probas – set this to True if you need the probabilities instead of raw answers
- attention_probs_keep_prob – keep_prob for Bert self-attention layers
- hidden_keep_prob – keep_prob for Bert hidden layers
- clip_norm – clip gradients by norm
- threshold – manually set value for defining the positively predicted classes (instead of an adaptive one)
- device – cpu/gpu device to use for training the model
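The interplay of threshold and return_probas can be sketched without the model itself. Below is a minimal, hypothetical illustration (not DeepPavlov code) of how positively predicted classes might be selected from per-class logits: with a manual threshold, every class whose probability exceeds it is positive; without one, an adaptive scheme compares each class logit against a dedicated threshold class (assumed here to be index 0).

```python
import numpy as np

def select_relations(logits, threshold=None):
    """Hypothetical sketch: pick positively predicted relation classes.

    logits: 1-D array of per-class scores.
    threshold: optional manual probability cutoff; if None, an
    adaptive scheme is used (class 0 acts as the threshold class).
    """
    logits = np.asarray(logits, dtype=float)
    if threshold is not None:
        # Manual mode: sigmoid each logit and keep classes above the cutoff.
        probas = 1.0 / (1.0 + np.exp(-logits))
        return [i for i, p in enumerate(probas) if p > threshold]
    # Adaptive mode: keep classes whose logit beats the threshold class.
    return [i for i, l in enumerate(logits[1:], start=1) if l > logits[0]]
```

With logits [0.0, 1.0, -1.0, 2.0], the adaptive mode keeps classes 1 and 3; a manual threshold of 0.7 keeps the same two (sigmoid(1.0) ≈ 0.73, sigmoid(2.0) ≈ 0.88).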

__call__(input_ids: List, attention_mask: List, entity_pos: List, entity_tags: List) → Union[List[int], List[numpy.ndarray]][source]

Get model predictions using features as input
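The input_ids and attention_mask features follow the usual BERT batching convention: sequences are padded to a common length, and the mask marks real tokens with 1 and padding with 0. A small, hypothetical helper (not part of DeepPavlov) shows the expected shapes:

```python
def pad_batch(token_id_seqs, pad_id=0):
    """Hypothetical helper: pad token-id sequences to equal length and
    build the matching attention_mask (1 = real token, 0 = padding)."""
    max_len = max(len(s) for s in token_id_seqs)
    input_ids = [s + [pad_id] * (max_len - len(s)) for s in token_id_seqs]
    attention_mask = [[1] * len(s) + [0] * (max_len - len(s))
                      for s in token_id_seqs]
    return input_ids, attention_mask
```

For example, padding [[101, 7592, 102], [101, 102]] yields input_ids [[101, 7592, 102], [101, 102, 0]] and attention_mask [[1, 1, 1], [1, 1, 0]].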

train_on_batch(input_ids: List, attention_mask: List, entity_pos: List, entity_tags: List, labels: List) → float[source]

Trains the relation extraction BERT model on the given batch.

Returns: dict with loss and learning rate values.
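During training, the clip_norm parameter caps the global gradient norm before the optimizer step. A minimal sketch of that operation (in plain NumPy, for illustration only; the model itself uses PyTorch's built-in clipping):

```python
import numpy as np

def clip_by_norm(grads, clip_norm):
    """Sketch of gradient clipping by global L2 norm: if the combined
    norm of all gradients exceeds clip_norm, scale every gradient
    down so the global norm equals clip_norm."""
    total = np.sqrt(sum(float(np.sum(g ** 2)) for g in grads))
    if total > clip_norm:
        grads = [g * (clip_norm / total) for g in grads]
    return grads
```

For instance, gradients with global norm 5.0 clipped to 1.0 are uniformly scaled by 0.2, preserving their direction.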