deeppavlov.models.relation_extraction

class deeppavlov.models.relation_extraction.relation_extraction_bert.REBertModel(n_classes: int, num_ner_tags: int, pretrained_bert: Optional[str] = None, return_probas: bool = False, threshold: Optional[float] = None, **kwargs)
__init__(n_classes: int, num_ner_tags: int, pretrained_bert: Optional[str] = None, return_probas: bool = False, threshold: Optional[float] = None, **kwargs) → None

Transformer-based PyTorch model for relation extraction. It predicts the relation that holds between entities in a text sample (one or several sentences).

:param n_classes: number of output classes
:param num_ner_tags: number of NER tags
:param pretrained_bert: name of the pretrained BERT model (e.g. "bert-base-uncased")
:param return_probas: set this to True if you need the probabilities instead of the predicted answers
:param threshold: manually set value that defines which classes are predicted as positive (used instead of the adaptive one)
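To make the role of `threshold` more concrete, the following is a minimal sketch, not the library's implementation. It assumes a multi-label setup in which each class probability is compared against one fixed cutoff with `>=`; the helper name `positive_classes` and the sample probabilities are hypothetical.

```python
from typing import List


def positive_classes(probas: List[float], threshold: float) -> List[int]:
    """Mark each class as positive (1) when its probability reaches the cutoff.

    Illustrative only: the comparison rule is an assumption, not taken
    from the REBertModel source.
    """
    return [1 if p >= threshold else 0 for p in probas]


# Hypothetical probabilities for one sample with n_classes = 4.
sample_probas = [0.05, 0.72, 0.40, 0.91]
predicted = positive_classes(sample_probas, threshold=0.5)
print(predicted)
```

With a fixed cutoff, raising `threshold` can only reduce the number of classes marked positive, which is the trade-off to keep in mind when overriding the adaptive value.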

__call__(input_ids: List, attention_mask: List, entity_pos: List, entity_tags: List) → Union[List[int], List[ndarray]]

Returns model predictions for a batch of input features.
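The `input_ids` and `attention_mask` arguments follow the usual BERT feature layout. As a rough sketch of how such a pair can be built from token id lists of different lengths, assuming a padding id of 0 and a non-empty batch; the helper `pad_batch` is hypothetical and not part of DeepPavlov, and the formats of `entity_pos` and `entity_tags` are not covered here.

```python
from typing import List, Tuple


def pad_batch(token_ids: List[List[int]], pad_id: int = 0) -> Tuple[List[List[int]], List[List[int]]]:
    """Pad every sequence to the longest one and build matching attention masks.

    Assumes the batch is non-empty; pad_id=0 is an assumption for illustration.
    """
    max_len = max(len(seq) for seq in token_ids)
    # Right-pad each sequence with pad_id up to max_len.
    input_ids = [seq + [pad_id] * (max_len - len(seq)) for seq in token_ids]
    # 1 marks a real token, 0 marks padding.
    attention_mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in token_ids]
    return input_ids, attention_mask


# Hypothetical token ids for two samples of different lengths.
ids, mask = pad_batch([[101, 7592, 102], [101, 102]])
print(ids)
print(mask)
```

In practice a tokenizer or a preprocessing component would normally produce these features; the sketch only shows the shape relationship between the two lists.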

train_on_batch(input_ids: List, attention_mask: List, entity_pos: List, entity_tags: List, labels: List) → float

Trains the relation extraction BERT model on the given batch.

:returns: dict with loss and learning rate values (note that the annotated return type in the signature is float).