Source code for deeppavlov.models.entity_extraction.ner_chunker

# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from logging import getLogger
from string import punctuation
from typing import List, Tuple, Union, Any

from nltk import sent_tokenize
from transformers import AutoTokenizer

from deeppavlov.core.common.registry import register
from deeppavlov.core.models.component import Component
from deeppavlov.core.common.chainer import Chainer
from deeppavlov.models.entity_extraction.entity_detection_parser import EntityDetectionParser

log = getLogger(__name__)


@register('ner_chunker')
class NerChunker(Component):
    """
    Class to split documents into chunks of at most max_seq_len subword tokens, so that their
    length does not exceed the maximal sequence length which can be fed into BERT
    """

    def __init__(self, vocab_file: str, max_seq_len: int = 400, lowercase: bool = False,
                 batch_size: int = 2, **kwargs):
        """
        Args:
            vocab_file: vocab file of pretrained transformer model
            max_seq_len: maximal length of chunks into which the document is split
            lowercase: whether to lowercase text
            batch_size: how many chunks are in a batch
        """
        self.max_seq_len = max_seq_len
        self.batch_size = batch_size
        self.re_tokenizer = re.compile(r"[\w']+|[^\w ]")
        self.tokenizer = AutoTokenizer.from_pretrained(vocab_file, do_lower_case=True)
        self.punct_ext = punctuation + " " + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
        self.russian_letters = "абвгдеёжзийклмнопрстуфхцчшщъыьэюя"
        self.lowercase = lowercase

    def __call__(self, docs_batch: List[str]) -> Tuple[List[List[str]], List[List[int]],
                                                        List[List[List[Tuple[int, int]]]],
                                                        List[List[List[str]]]]:
        """
        This method splits each document in the batch into chunks with the maximal length of max_seq_len

        Args:
            docs_batch: batch of documents

        Returns:
            batch of lists of document chunks for each document
            batch of lists of numbers of documents which correspond to chunks
            batch of lists of start and end offsets of sentences in each chunk
            batch of lists of sentences for each chunk
        """
        text_batch_list, nums_batch_list, sentences_offsets_batch_list, sentences_batch_list = [], [], [], []
        text_batch, nums_batch, sentences_offsets_batch, sentences_batch = [], [], [], []

        for n, doc in enumerate(docs_batch):
            if self.lowercase:
                doc = doc.lower()
            start = 0
            text = ""
            sentences_list = []
            sentences_offsets_list = []
            cur_len = 0
            doc_pieces = doc.split("\n")
            doc_pieces = [self.sanitize(doc_piece) for doc_piece in doc_pieces]
            doc_pieces = [doc_piece for doc_piece in doc_pieces if len(doc_piece) > 1]
            if doc_pieces:
                sentences = []
                for doc_piece in doc_pieces:
                    sentences += sent_tokenize(doc_piece)
                for sentence in sentences:
                    sentence_tokens = re.findall(self.re_tokenizer, sentence)
                    sentence_len = sum([len(self.tokenizer.encode_plus(token, add_special_tokens=False)["input_ids"])
                                        for token in sentence_tokens])
                    if cur_len + sentence_len < self.max_seq_len:
                        text += f"{sentence} "
                        cur_len += sentence_len
                        end = start + len(sentence)
                        sentences_offsets_list.append((start, end))
                        sentences_list.append(sentence)
                        start = end + 1
                    else:
                        text = text.strip()
                        if text:
                            text_batch.append(text)
                            sentences_offsets_batch.append(sentences_offsets_list)
                            sentences_batch.append(sentences_list)
                            nums_batch.append(n)
                        if sentence_len < self.max_seq_len:
                            text = f"{sentence} "
                            cur_len = sentence_len
                            start = 0
                            end = start + len(sentence)
                            sentences_offsets_list = [(start, end)]
                            sentences_list = [sentence]
                            start = end + 1
                        else:
                            text = ""
                            sentence_chunks = sentence.split(" ")
                            for chunk in sentence_chunks:
                                chunk_tokens = re.findall(self.re_tokenizer, chunk)
                                chunk_len = sum([len(self.tokenizer.encode_plus(token,
                                                                                add_special_tokens=False)["input_ids"])
                                                 for token in chunk_tokens])
                                if cur_len + chunk_len < self.max_seq_len:
                                    text += f"{chunk} "
                                    cur_len += chunk_len + 1
                                    end = start + len(chunk)
                                    sentences_offsets_list.append((start, end))
                                    sentences_list.append(chunk)
                                    start = end + 1
                                else:
                                    text = text.strip()
                                    if text:
                                        text_batch.append(text)
                                        sentences_offsets_batch.append(sentences_offsets_list)
                                        sentences_batch.append(sentences_list)
                                        nums_batch.append(n)
                                    text = f"{chunk} "
                                    cur_len = chunk_len
                                    start = 0
                                    end = start + len(chunk)
                                    sentences_offsets_list = [(start, end)]
                                    sentences_list = [chunk]
                                    start = end + 1

                text = text.strip().strip(",")
                if text:
                    text_batch.append(text)
                    nums_batch.append(n)
                    sentences_offsets_batch.append(sentences_offsets_list)
                    sentences_batch.append(sentences_list)
            else:
                text_batch.append("а")
                nums_batch.append(n)
                sentences_offsets_batch.append([(0, len(doc))])
                sentences_batch.append([doc])

        num_batches = len(text_batch) // self.batch_size + int(len(text_batch) % self.batch_size > 0)
        for jj in range(num_batches):
            text_batch_list.append(text_batch[jj * self.batch_size:(jj + 1) * self.batch_size])
            nums_batch_list.append(nums_batch[jj * self.batch_size:(jj + 1) * self.batch_size])
            sentences_offsets_batch_list.append(
                sentences_offsets_batch[jj * self.batch_size:(jj + 1) * self.batch_size])
            sentences_batch_list.append(sentences_batch[jj * self.batch_size:(jj + 1) * self.batch_size])

        return text_batch_list, nums_batch_list, sentences_offsets_batch_list, sentences_batch_list

    def sanitize(self, text):
        text_len = len(text)
        if text_len > 0 and text[text_len - 1] not in {'.', '!', '?'}:
            i = text_len - 1
            while text[i] in self.punct_ext and i > 0:
                i -= 1
                if (text[i] in {'.', '!', '?'} and text[i - 1].lower() in self.russian_letters) or \
                        (i > 1 and text[i] in {'.', '!', '?'} and text[i - 1] in '"'
                         and text[i - 2].lower() in self.russian_letters):
                    break
            text = text[:i + 1]
        text = re.sub(r'\s+', ' ', text)
        return text

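
# The following usage sketch is illustrative only and is not part of the original module.
# The vocab_file value is an assumption: any tokenizer name accepted by
# AutoTokenizer.from_pretrained (e.g. "bert-base-multilingual-cased") works the same way.
def _example_chunking():
    chunker = NerChunker(vocab_file="bert-base-multilingual-cased", max_seq_len=400, batch_size=2)
    docs = ["Александр Пушкин родился в Москве. Он является автором романа 'Евгений Онегин'. " * 100]
    text_batch_list, nums_batch_list, sent_offsets_batch_list, sent_batch_list = chunker(docs)
    # Every inner list of text_batch_list holds at most batch_size chunks, each chunk fitting into
    # max_seq_len subword tokens; nums_batch_list maps each chunk back to the index of its source
    # document, which NerChunkModel uses below to reassemble per-document annotations.
    return text_batch_list, nums_batch_list, sent_offsets_batch_list, sent_batch_list

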
@register('ner_chunk_model')
class NerChunkModel(Component):
    """
    Class for linking of entity substrings in the document to entities in Wikidata
    """

    def __init__(self, ner: Chainer,
                 ner_parser: EntityDetectionParser,
                 ner2: Chainer = None,
                 ner_parser2: EntityDetectionParser = None,
                 **kwargs) -> None:
        """
        Args:
            ner: config for entity detection
            ner_parser: component deeppavlov.models.entity_extraction.entity_detection_parser
            ner2: config of additional entity detection model (an ensemble of ner and ner2 models gives
                better entity detection quality than a single ner model)
            ner_parser2: component deeppavlov.models.entity_extraction.entity_detection_parser
            **kwargs:
        """
        self.ner = ner
        self.ner_parser = ner_parser
        self.ner2 = ner2
        self.ner_parser2 = ner_parser2

    def __call__(self, text_batch_list: List[List[str]],
                 nums_batch_list: List[List[int]],
                 sentences_offsets_batch_list: List[List[List[Tuple[int, int]]]],
                 sentences_batch_list: List[List[List[str]]]
                 ):
        """
        Args:
            text_batch_list: list of document chunks
            nums_batch_list: nums of documents
            sentences_offsets_batch_list: indices of start and end symbols of sentences in text
            sentences_batch_list: list of sentences from texts

        Returns:
            doc_entity_substr_batch: entity substrings
            doc_entity_offsets_batch: indices of start and end symbols of entities in text
            doc_entity_positions_batch: positions of entity tokens in the document token sequence
            doc_tags_batch: entity tags (PER, LOC, ORG)
            doc_sentences_offsets_batch: indices of start and end symbols of sentences in text
            doc_sentences_batch: list of sentences from texts
            doc_probas_batch: confidences of detected entities
        """
        entity_substr_batch_list, entity_offsets_batch_list, entity_positions_batch_list, tags_batch_list, \
            entity_probas_batch_list, text_len_batch_list, text_tokens_len_batch_list = [], [], [], [], [], [], []
        for text_batch, sentences_offsets_batch, sentences_batch in \
                zip(text_batch_list, sentences_offsets_batch_list, sentences_batch_list):
            text_batch = [text.replace("\xad", " ") for text in text_batch]

            ner_tokens_batch, ner_tokens_offsets_batch, ner_probas_batch, probas_batch = self.ner(text_batch)
            entity_substr_batch, entity_positions_batch, entity_probas_batch = \
                self.ner_parser(ner_tokens_batch, ner_probas_batch, probas_batch)
            if self.ner2:
                ner_tokens_batch2, ner_tokens_offsets_batch2, ner_probas_batch2, probas_batch2 = self.ner2(text_batch)
                entity_substr_batch2, entity_positions_batch2, entity_probas_batch2 = \
                    self.ner_parser2(ner_tokens_batch2, ner_probas_batch2, probas_batch2)
                entity_substr_batch, entity_positions_batch, entity_probas_batch = \
                    self.merge_annotations(entity_substr_batch, entity_positions_batch, entity_probas_batch,
                                           entity_substr_batch2, entity_positions_batch2, entity_probas_batch2)

            entity_pos_tags_probas_batch = [[(entity_substr.lower(), entity_substr_positions, tag, entity_proba)
                                             for tag, entity_substr_list in entity_substr_dict.items()
                                             for entity_substr, entity_substr_positions, entity_proba in
                                             zip(entity_substr_list, entity_positions_dict[tag],
                                                 entity_probas_dict[tag])]
                                            for entity_substr_dict, entity_positions_dict, entity_probas_dict in
                                            zip(entity_substr_batch, entity_positions_batch, entity_probas_batch)]

            entity_substr_batch, entity_offsets_batch, entity_positions_batch, tags_batch, \
                probas_batch = [], [], [], [], []
            for entity_pos_tags_probas, ner_tokens_offsets_list in \
                    zip(entity_pos_tags_probas_batch, ner_tokens_offsets_batch):
                if entity_pos_tags_probas:
                    entity_offsets_list = []
                    entity_substr_list, entity_positions_list, tags_list, probas_list = zip(*entity_pos_tags_probas)
                    for entity_positions in entity_positions_list:
                        start_offset = ner_tokens_offsets_list[entity_positions[0]][0]
                        end_offset = ner_tokens_offsets_list[entity_positions[-1]][1]
                        entity_offsets_list.append((start_offset, end_offset))
                else:
                    entity_substr_list, entity_offsets_list, entity_positions_list = [], [], []
                    tags_list, probas_list = [], []
                entity_substr_batch.append(list(entity_substr_list))
                entity_offsets_batch.append(list(entity_offsets_list))
                entity_positions_batch.append(list(entity_positions_list))
                tags_batch.append(list(tags_list))
                probas_batch.append(list(probas_list))

            entity_substr_batch_list.append(entity_substr_batch)
            tags_batch_list.append(tags_batch)
            entity_offsets_batch_list.append(entity_offsets_batch)
            entity_positions_batch_list.append(entity_positions_batch)
            entity_probas_batch_list.append(probas_batch)
            text_len_batch_list.append([len(text) for text in text_batch])
            text_tokens_len_batch_list.append([len(ner_tokens) for ner_tokens in ner_tokens_batch])

        doc_entity_substr_batch, doc_tags_batch, doc_entity_offsets_batch, doc_probas_batch = [], [], [], []
        doc_entity_positions_batch, doc_sentences_offsets_batch, doc_sentences_batch = [], [], []
        doc_entity_substr, doc_tags, doc_probas, doc_entity_offsets, doc_entity_positions = [], [], [], [], []
        doc_sentences_offsets, doc_sentences = [], []
        cur_doc_num = 0
        text_len_sum = 0
        text_tokens_len_sum = 0
        for entity_substr_batch, tags_batch, probas_batch, entity_offsets_batch, entity_positions_batch, \
            sentences_offsets_batch, sentences_batch, text_len_batch, text_tokens_len_batch, nums_batch in \
                zip(entity_substr_batch_list, tags_batch_list, entity_probas_batch_list, entity_offsets_batch_list,
                    entity_positions_batch_list, sentences_offsets_batch_list, sentences_batch_list,
                    text_len_batch_list, text_tokens_len_batch_list, nums_batch_list):
            for entity_substr_list, tag_list, probas_list, entity_offsets_list, entity_positions_list, \
                sentences_offsets_list, sentences_list, text_len, text_tokens_len, doc_num in \
                    zip(entity_substr_batch, tags_batch, probas_batch, entity_offsets_batch, entity_positions_batch,
                        sentences_offsets_batch, sentences_batch, text_len_batch, text_tokens_len_batch, nums_batch):
                if doc_num == cur_doc_num:
                    doc_entity_substr += entity_substr_list
                    doc_tags += tag_list
                    doc_probas += probas_list
                    doc_entity_offsets += [(start_offset + text_len_sum, end_offset + text_len_sum)
                                           for start_offset, end_offset in entity_offsets_list]
                    doc_sentences_offsets += [(start_offset + text_len_sum, end_offset + text_len_sum)
                                              for start_offset, end_offset in sentences_offsets_list]
                    doc_entity_positions += [[pos + text_tokens_len_sum for pos in positions]
                                             for positions in entity_positions_list]
                    doc_sentences += sentences_list
                    text_len_sum += text_len + 1
                    text_tokens_len_sum += text_tokens_len
                else:
                    doc_entity_substr_batch.append(doc_entity_substr)
                    doc_tags_batch.append(doc_tags)
                    doc_probas_batch.append(doc_probas)
                    doc_entity_offsets_batch.append(doc_entity_offsets)
                    doc_entity_positions_batch.append(doc_entity_positions)
                    doc_sentences_offsets_batch.append(doc_sentences_offsets)
                    doc_sentences_batch.append(doc_sentences)
                    doc_entity_substr = entity_substr_list
                    doc_tags = tag_list
                    doc_probas = probas_list
                    doc_entity_offsets = entity_offsets_list
                    doc_sentences_offsets = sentences_offsets_list
                    doc_sentences = sentences_list
                    cur_doc_num = doc_num
                    text_len_sum = text_len + 1
                    text_tokens_len_sum = text_tokens_len
        doc_entity_substr_batch.append(doc_entity_substr)
        doc_tags_batch.append(doc_tags)
        doc_probas_batch.append(doc_probas)
        doc_entity_offsets_batch.append(doc_entity_offsets)
        doc_entity_positions_batch.append(doc_entity_positions)
        doc_sentences_offsets_batch.append(doc_sentences_offsets)
        doc_sentences_batch.append(doc_sentences)

        return doc_entity_substr_batch, doc_entity_offsets_batch, doc_entity_positions_batch, doc_tags_batch, \
            doc_sentences_offsets_batch, doc_sentences_batch, doc_probas_batch

    def merge_annotations(self, substr_batch, pos_batch, probas_batch, substr_batch2, pos_batch2, probas_batch2):
        log.debug(f"ner_chunker, substr2: {substr_batch2} --- pos2: {pos_batch2} --- probas2: {probas_batch2} --- "
                  f"substr: {substr_batch} --- pos: {pos_batch} --- probas: {probas_batch}")
        for i in range(len(substr_batch)):
            for key2 in substr_batch2[i]:
                substr_list2 = substr_batch2[i][key2]
                pos_list2 = pos_batch2[i][key2]
                probas_list2 = probas_batch2[i][key2]
                for substr2, pos2, probas2 in zip(substr_list2, pos_list2, probas_list2):
                    found = False
                    for key in substr_batch[i]:
                        pos_list = pos_batch[i][key]
                        for pos in pos_list:
                            if pos[0] <= pos2[0] <= pos[-1] or pos[0] <= pos2[-1] <= pos[-1]:
                                found = True
                    if not found:
                        if key2 not in substr_batch[i]:
                            substr_batch[i][key2] = []
                            pos_batch[i][key2] = []
                            probas_batch[i][key2] = []
                        substr_batch[i][key2].append(substr2)
                        pos_batch[i][key2].append(pos2)
                        probas_batch[i][key2].append(probas2)
        for i in range(len(substr_batch)):
            for key2 in substr_batch2[i]:
                substr_list2 = substr_batch2[i][key2]
                pos_list2 = pos_batch2[i][key2]
                probas_list2 = probas_batch2[i][key2]
                for substr2, pos2, probas2 in zip(substr_list2, pos_list2, probas_list2):
                    for key in substr_batch[i]:
                        inds = []
                        substr_list = substr_batch[i][key]
                        pos_list = pos_batch[i][key]
                        probas_list = probas_batch[i][key]
                        for n, (substr, pos, probas) in enumerate(zip(substr_list, pos_list, probas_list)):
                            if (pos[0] == pos2[0] and pos[-1] < pos2[-1]) or \
                                    (pos[0] > pos2[0] and pos[-1] == pos2[-1]):
                                inds.append(n)
                            elif key == "EVENT" and ((pos[0] >= pos2[0] and pos[-1] <= pos2[-1])
                                                     or (len(substr.split()) == 1 and pos2[0] <= pos[0])):
                                inds.append(n)
                        if (len(inds) > 1 or (len(inds) == 1 and key in {"WORK_OF_ART", "EVENT"})) \
                                and not (key == "PERSON" and " и " in substr2):
                            inds = sorted(inds, reverse=True)
                            for ind in inds:
                                del substr_batch[i][key][ind]
                                del pos_batch[i][key][ind]
                                del probas_batch[i][key][ind]
                            substr_batch[i][key].append(substr2)
                            pos_batch[i][key].append(pos2)
                            probas_batch[i][key].append(probas2)
        return substr_batch, pos_batch, probas_batch
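

# Illustrative sketch of merge_annotations on toy per-chunk annotation dictionaries; it is not
# part of the original module, and the ner/ner_parser arguments are set to None only because
# merge_annotations itself does not touch them. Spans of the second model are added when they do
# not overlap any span of the first model; overlapping first-model spans that share a boundary with
# a longer second-model span (with extra rules for EVENT and WORK_OF_ART) are replaced by it.
def _example_merge():
    model = NerChunkModel(ner=None, ner_parser=None)
    substr = [{"PER": ["Пушкин"]}]          # first model: one PER span
    pos = [{"PER": [[0, 1]]}]               # token positions of that span
    probas = [{"PER": [0.9]}]
    substr2 = [{"LOC": ["Москве"]}]         # second model: a LOC span at non-overlapping positions
    pos2 = [{"LOC": [[4, 4]]}]
    probas2 = [{"LOC": [0.8]}]
    merged_substr, merged_pos, merged_probas = model.merge_annotations(substr, pos, probas,
                                                                       substr2, pos2, probas2)
    # merged_substr == [{'PER': ['Пушкин'], 'LOC': ['Москве']}]: the non-overlapping LOC span is kept.
    return merged_substr, merged_pos, merged_probas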