# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from logging import getLogger
from string import punctuation
from typing import List, Tuple, Union, Any
from nltk import sent_tokenize
from transformers import AutoTokenizer
from deeppavlov.core.common.registry import register
from deeppavlov.core.models.component import Component
from deeppavlov.core.common.chainer import Chainer
from deeppavlov.models.entity_extraction.entity_detection_parser import EntityDetectionParser
log = getLogger(__name__)
@register('ner_chunker')
class NerChunker(Component):
"""
Class to split documents into chunks of max_seq_len symbols so that the length will not exceed
maximal sequence length to feed into BERT
"""
    def __init__(self, vocab_file: str, max_seq_len: int = 400, lowercase: bool = False, batch_size: int = 2, **kwargs):
"""
Args:
vocab_file: vocab file of pretrained transformer model
max_seq_len: maximal length of chunks into which the document is split
lowercase: whether to lowercase text
batch_size: how many chunks are in batch
"""
self.max_seq_len = max_seq_len
self.batch_size = batch_size
self.re_tokenizer = re.compile(r"[\w']+|[^\w ]")
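        # the tokenizer is only used to count subtokens per word when packing sentences into
        # chunks; note that do_lower_case=True is always passed to the tokenizer, while the
        # `lowercase` flag only controls lowercasing of the raw document text in __call__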
self.tokenizer = AutoTokenizer.from_pretrained(vocab_file,
do_lower_case=True)
self.punct_ext = punctuation + " " + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
self.russian_letters = "абвгдеёжзийклмнопрстуфхцчшщъыьэюя"
self.lowercase = lowercase
    def __call__(self, docs_batch: List[str]) -> Tuple[List[List[str]], List[List[int]],
                                                        List[List[List[Tuple[int, int]]]],
                                                        List[List[List[str]]]]:
"""
This method splits each document in the batch into chunks wuth the maximal length of max_seq_len
Args:
docs_batch: batch of documents
Returns:
batch of lists of document chunks for each document
batch of lists of numbers of documents which correspond to chunks
"""
text_batch_list, nums_batch_list, sentences_offsets_batch_list, sentences_batch_list = [], [], [], []
text_batch, nums_batch, sentences_offsets_batch, sentences_batch = [], [], [], []
for n, doc in enumerate(docs_batch):
if self.lowercase:
doc = doc.lower()
start = 0
text = ""
sentences_list = []
sentences_offsets_list = []
cur_len = 0
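            # split the document into pieces by newlines, trim trailing garbage from every piece
            # with sanitize() and drop pieces that are (almost) empty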
doc_pieces = doc.split("\n")
doc_pieces = [self.sanitize(doc_piece) for doc_piece in doc_pieces]
doc_pieces = [doc_piece for doc_piece in doc_pieces if len(doc_piece) > 1]
if doc_pieces:
sentences = []
for doc_piece in doc_pieces:
sentences += sent_tokenize(doc_piece)
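                # pack sentences greedily into chunks, counting the length in tokenizer subtokens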
for sentence in sentences:
sentence_tokens = re.findall(self.re_tokenizer, sentence)
sentence_len = sum([len(self.tokenizer.encode_plus(token, add_special_tokens=False)["input_ids"])
for token in sentence_tokens])
if cur_len + sentence_len < self.max_seq_len:
text += f"{sentence} "
cur_len += sentence_len
end = start + len(sentence)
sentences_offsets_list.append((start, end))
sentences_list.append(sentence)
start = end + 1
else:
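                        # the sentence does not fit into the current chunk: flush the accumulated
                        # chunk (if any) and start a new one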
text = text.strip()
if text:
text_batch.append(text)
sentences_offsets_batch.append(sentences_offsets_list)
sentences_batch.append(sentences_list)
nums_batch.append(n)
if sentence_len < self.max_seq_len:
text = f"{sentence} "
cur_len = sentence_len
start = 0
end = start + len(sentence)
sentences_offsets_list = [(start, end)]
sentences_list = [sentence]
start = end + 1
                        else:
                            # the sentence alone does not fit into max_seq_len: split it into word
                            # chunks, starting from a fresh chunk state (the old lists were flushed above)
                            text, cur_len, start = "", 0, 0
                            sentences_offsets_list, sentences_list = [], []
                            sentence_chunks = sentence.split(" ")
for chunk in sentence_chunks:
chunk_tokens = re.findall(self.re_tokenizer, chunk)
chunk_len = sum([len(self.tokenizer.encode_plus(token,
add_special_tokens=False)["input_ids"])
for token in chunk_tokens])
if cur_len + chunk_len < self.max_seq_len:
text += f"{chunk} "
cur_len += chunk_len + 1
end = start + len(chunk)
sentences_offsets_list.append((start, end))
sentences_list.append(chunk)
start = end + 1
else:
text = text.strip()
if text:
text_batch.append(text)
sentences_offsets_batch.append(sentences_offsets_list)
sentences_batch.append(sentences_list)
nums_batch.append(n)
text = f"{chunk} "
cur_len = chunk_len
start = 0
end = start + len(chunk)
sentences_offsets_list = [(start, end)]
sentences_list = [chunk]
start = end + 1
text = text.strip().strip(",")
if text:
text_batch.append(text)
nums_batch.append(n)
sentences_offsets_batch.append(sentences_offsets_list)
sentences_batch.append(sentences_list)
else:
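                # the document contains no usable text; emit a one-letter placeholder chunk so that
                # the downstream NER model still receives a non-empty input for this document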
text_batch.append("а")
nums_batch.append(n)
sentences_offsets_batch.append([(0, len(doc))])
sentences_batch.append([doc])
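        # regroup the flat lists of chunks into batches of batch_size chunks each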
num_batches = len(text_batch) // self.batch_size + int(len(text_batch) % self.batch_size > 0)
for jj in range(num_batches):
text_batch_list.append(text_batch[jj * self.batch_size:(jj + 1) * self.batch_size])
nums_batch_list.append(nums_batch[jj * self.batch_size:(jj + 1) * self.batch_size])
sentences_offsets_batch_list.append(
sentences_offsets_batch[jj * self.batch_size:(jj + 1) * self.batch_size])
sentences_batch_list.append(sentences_batch[jj * self.batch_size:(jj + 1) * self.batch_size])
return text_batch_list, nums_batch_list, sentences_offsets_batch_list, sentences_batch_list
def sanitize(self, text):
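        """Truncate a trailing run of punctuation, spaces, Latin letters and digits from a piece
        that does not end with sentence-final punctuation (cutting back to the last Russian letter,
        or to a sentence terminator that follows one), and collapse repeated whitespace."""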
text_len = len(text)
if text_len > 0 and text[text_len - 1] not in {'.', '!', '?'}:
i = text_len - 1
while text[i] in self.punct_ext and i > 0:
i -= 1
if (text[i] in {'.', '!', '?'} and text[i - 1].lower() in self.russian_letters) or \
(i > 1 and text[i] in {'.', '!', '?'} and text[i - 1] in '"' and text[
i - 2].lower() in self.russian_letters):
break
text = text[:i + 1]
text = re.sub(r'\s+', ' ', text)
return text
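# A minimal usage sketch of NerChunker (illustrative only; the tokenizer name below is an
# arbitrary assumption, any vocab/model name accepted by AutoTokenizer.from_pretrained works):
#
#     chunker = NerChunker(vocab_file="bert-base-multilingual-cased", max_seq_len=400, batch_size=2)
#     texts, nums, sent_offsets, sents = chunker(["Первый документ. Второй абзац.", "Другой документ."])
#     # texts[i] holds up to batch_size chunks, nums[i] maps every chunk back to the index of its
#     # source document in docs_batch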
@register('ner_chunk_model')
class NerChunkModel(Component):
"""
Class for linking of entity substrings in the document to entities in Wikidata
"""
def __init__(self, ner: Chainer,
ner_parser: EntityDetectionParser,
ner2: Chainer = None,
ner_parser2: EntityDetectionParser = None,
**kwargs) -> None:
"""
Args:
ner: config for entity detection
ner_parser: component deeppavlov.models.entity_extraction.entity_detection_parser
ner2: config of additional entity detection model (ensemble of ner and ner2 models gives better
entity detection quality than single ner model)
ner_parser2: component deeppavlov.models.entity_extraction.entity_detection_parser
**kwargs:
"""
self.ner = ner
self.ner_parser = ner_parser
self.ner2 = ner2
self.ner_parser2 = ner_parser2
def __call__(self, text_batch_list: List[List[str]],
nums_batch_list: List[List[int]],
sentences_offsets_batch_list: List[List[List[Tuple[int, int]]]],
sentences_batch_list: List[List[List[str]]]
):
"""
Args:
text_batch_list: list of document chunks
nums_batch_list: nums of documents
sentences_offsets_batch_list: indices of start and end symbols of sentences in text
sentences_batch_list: list of sentences from texts
Returns:
doc_entity_substr_batch: entity substrings
doc_entity_offsets_batch: indices of start and end symbols of entities in text
doc_tags_batch: entity tags (PER, LOC, ORG)
doc_sentences_offsets_batch: indices of start and end symbols of sentences in text
doc_sentences_batch: list of sentences from texts
"""
entity_substr_batch_list, entity_offsets_batch_list, entity_positions_batch_list, tags_batch_list, \
entity_probas_batch_list, text_len_batch_list, text_tokens_len_batch_list = [], [], [], [], [], [], []
for text_batch, sentences_offsets_batch, sentences_batch in \
zip(text_batch_list, sentences_offsets_batch_list, sentences_batch_list):
text_batch = [text.replace("\xad", " ") for text in text_batch]
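            # run the main entity detection model on the chunks and parse its output into entity
            # substrings, token positions and tag confidences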
ner_tokens_batch, ner_tokens_offsets_batch, ner_probas_batch, probas_batch = self.ner(text_batch)
entity_substr_batch, entity_positions_batch, entity_probas_batch = \
self.ner_parser(ner_tokens_batch, ner_probas_batch, probas_batch)
if self.ner2:
ner_tokens_batch2, ner_tokens_offsets_batch2, ner_probas_batch2, probas_batch2 = self.ner2(text_batch)
entity_substr_batch2, entity_positions_batch2, entity_probas_batch2 = \
self.ner_parser2(ner_tokens_batch2, ner_probas_batch2, probas_batch2)
entity_substr_batch, entity_positions_batch, entity_probas_batch = \
self.merge_annotations(entity_substr_batch, entity_positions_batch, entity_probas_batch,
entity_substr_batch2, entity_positions_batch2, entity_probas_batch2)
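            # flatten the per-tag dictionaries into a single list of
            # (substring, token positions, tag, confidence) tuples per chunk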
entity_pos_tags_probas_batch = [[(entity_substr.lower(), entity_substr_positions, tag, entity_proba)
for tag, entity_substr_list in entity_substr_dict.items()
for entity_substr, entity_substr_positions, entity_proba in
zip(entity_substr_list, entity_positions_dict[tag],
entity_probas_dict[tag])]
for entity_substr_dict, entity_positions_dict, entity_probas_dict in
zip(entity_substr_batch, entity_positions_batch, entity_probas_batch)]
entity_substr_batch, entity_offsets_batch, entity_positions_batch, tags_batch, \
probas_batch = [], [], [], [], []
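            # convert the token positions of every entity into character offsets within the chunk
            # using the offsets of the NER tokens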
for entity_pos_tags_probas, ner_tokens_offsets_list in \
zip(entity_pos_tags_probas_batch, ner_tokens_offsets_batch):
if entity_pos_tags_probas:
entity_offsets_list = []
entity_substr_list, entity_positions_list, tags_list, probas_list = zip(*entity_pos_tags_probas)
for entity_positions in entity_positions_list:
start_offset = ner_tokens_offsets_list[entity_positions[0]][0]
end_offset = ner_tokens_offsets_list[entity_positions[-1]][1]
entity_offsets_list.append((start_offset, end_offset))
else:
entity_substr_list, entity_offsets_list, entity_positions_list = [], [], []
tags_list, probas_list = [], []
entity_substr_batch.append(list(entity_substr_list))
entity_offsets_batch.append(list(entity_offsets_list))
entity_positions_batch.append(list(entity_positions_list))
tags_batch.append(list(tags_list))
probas_batch.append(list(probas_list))
entity_substr_batch_list.append(entity_substr_batch)
tags_batch_list.append(tags_batch)
entity_offsets_batch_list.append(entity_offsets_batch)
entity_positions_batch_list.append(entity_positions_batch)
entity_probas_batch_list.append(probas_batch)
text_len_batch_list.append([len(text) for text in text_batch])
text_tokens_len_batch_list.append([len(ner_tokens) for ner_tokens in ner_tokens_batch])
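        # merge chunk-level annotations back into document-level annotations: character offsets and
        # token positions of every chunk are shifted by the cumulative length of the preceding
        # chunks of the same document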
doc_entity_substr_batch, doc_tags_batch, doc_entity_offsets_batch, doc_probas_batch = [], [], [], []
doc_entity_positions_batch, doc_sentences_offsets_batch, doc_sentences_batch = [], [], []
doc_entity_substr, doc_tags, doc_probas, doc_entity_offsets, doc_entity_positions = [], [], [], [], []
doc_sentences_offsets, doc_sentences = [], []
cur_doc_num = 0
text_len_sum = 0
text_tokens_len_sum = 0
for entity_substr_batch, tags_batch, probas_batch, entity_offsets_batch, entity_positions_batch, \
sentences_offsets_batch, sentences_batch, text_len_batch, text_tokens_len_batch, nums_batch in \
zip(entity_substr_batch_list, tags_batch_list, entity_probas_batch_list, entity_offsets_batch_list,
entity_positions_batch_list, sentences_offsets_batch_list, sentences_batch_list,
text_len_batch_list, text_tokens_len_batch_list, nums_batch_list):
for entity_substr_list, tag_list, probas_list, entity_offsets_list, entity_positions_list, \
sentences_offsets_list, sentences_list, text_len, text_tokens_len, doc_num in \
zip(entity_substr_batch, tags_batch, probas_batch, entity_offsets_batch, entity_positions_batch,
sentences_offsets_batch, sentences_batch, text_len_batch, text_tokens_len_batch, nums_batch):
if doc_num == cur_doc_num:
doc_entity_substr += entity_substr_list
doc_tags += tag_list
doc_probas += probas_list
doc_entity_offsets += [(start_offset + text_len_sum, end_offset + text_len_sum)
for start_offset, end_offset in entity_offsets_list]
doc_sentences_offsets += [(start_offset + text_len_sum, end_offset + text_len_sum)
for start_offset, end_offset in sentences_offsets_list]
doc_entity_positions += [[pos + text_tokens_len_sum for pos in positions]
for positions in entity_positions_list]
doc_sentences += sentences_list
text_len_sum += text_len + 1
text_tokens_len_sum += text_tokens_len
else:
doc_entity_substr_batch.append(doc_entity_substr)
doc_tags_batch.append(doc_tags)
doc_probas_batch.append(doc_probas)
doc_entity_offsets_batch.append(doc_entity_offsets)
doc_entity_positions_batch.append(doc_entity_positions)
doc_sentences_offsets_batch.append(doc_sentences_offsets)
doc_sentences_batch.append(doc_sentences)
doc_entity_substr = entity_substr_list
doc_tags = tag_list
doc_probas = probas_list
                    doc_entity_offsets = entity_offsets_list
                    doc_entity_positions = entity_positions_list
doc_sentences_offsets = sentences_offsets_list
doc_sentences = sentences_list
cur_doc_num = doc_num
text_len_sum = text_len + 1
text_tokens_len_sum = text_tokens_len
doc_entity_substr_batch.append(doc_entity_substr)
doc_tags_batch.append(doc_tags)
doc_probas_batch.append(doc_probas)
doc_entity_offsets_batch.append(doc_entity_offsets)
doc_entity_positions_batch.append(doc_entity_positions)
doc_sentences_offsets_batch.append(doc_sentences_offsets)
doc_sentences_batch.append(doc_sentences)
return doc_entity_substr_batch, doc_entity_offsets_batch, doc_entity_positions_batch, doc_tags_batch, \
doc_sentences_offsets_batch, doc_sentences_batch, doc_probas_batch
def merge_annotations(self, substr_batch, pos_batch, probas_batch, substr_batch2, pos_batch2, probas_batch2):
log.debug(f"ner_chunker, substr2: {substr_batch2} --- pos2: {pos_batch2} --- probas2: {probas_batch2} --- "
f"substr: {substr_batch} --- pos: {pos_batch} --- probas: {probas_batch}")
for i in range(len(substr_batch)):
for key2 in substr_batch2[i]:
substr_list2 = substr_batch2[i][key2]
pos_list2 = pos_batch2[i][key2]
probas_list2 = probas_batch2[i][key2]
for substr2, pos2, probas2 in zip(substr_list2, pos_list2, probas_list2):
found = False
for key in substr_batch[i]:
pos_list = pos_batch[i][key]
for pos in pos_list:
if pos[0] <= pos2[0] <= pos[-1] or pos[0] <= pos2[-1] <= pos[-1]:
found = True
if not found:
if key2 not in substr_batch[i]:
substr_batch[i][key2] = []
pos_batch[i][key2] = []
probas_batch[i][key2] = []
substr_batch[i][key2].append(substr2)
pos_batch[i][key2].append(pos2)
probas_batch[i][key2].append(probas2)
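        # second pass: when a span of the second model subsumes spans of the first model (sharing a
        # boundary with a strictly longer span, or, for EVENT, nested inside it), the shorter spans
        # are replaced by the longer one; a single span is only replaced for the WORK_OF_ART and
        # EVENT tags, and PERSON spans containing " и " ("and") are never used as replacements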
for i in range(len(substr_batch)):
for key2 in substr_batch2[i]:
substr_list2 = substr_batch2[i][key2]
pos_list2 = pos_batch2[i][key2]
probas_list2 = probas_batch2[i][key2]
for substr2, pos2, probas2 in zip(substr_list2, pos_list2, probas_list2):
for key in substr_batch[i]:
inds = []
substr_list = substr_batch[i][key]
pos_list = pos_batch[i][key]
probas_list = probas_batch[i][key]
for n, (substr, pos, probas) in enumerate(zip(substr_list, pos_list, probas_list)):
if (pos[0] == pos2[0] and pos[-1] < pos2[-1]) or (pos[0] > pos2[0] and pos[-1] == pos2[-1]):
inds.append(n)
elif key == "EVENT" and ((pos[0] >= pos2[0] and pos[-1] <= pos2[-1])
or (len(substr.split()) == 1 and pos2[0] <= pos[0])):
inds.append(n)
if (len(inds) > 1 or (len(inds) == 1 and key in {"WORK_OF_ART", "EVENT"})) \
and not (key == "PERSON" and " и " in substr2):
inds = sorted(inds, reverse=True)
for ind in inds:
del substr_batch[i][key][ind]
del pos_batch[i][key][ind]
del probas_batch[i][key][ind]
substr_batch[i][key].append(substr2)
pos_batch[i][key].append(pos2)
probas_batch[i][key].append(probas2)
return substr_batch, pos_batch, probas_batch
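# A minimal sketch of the chunk -> NER -> merge flow (illustrative only; `ner`, `ner_parser` and
# `docs` stand for an already built entity detection Chainer, an EntityDetectionParser and a list
# of documents, they are not defined here):
#
#     chunker = NerChunker(vocab_file="bert-base-multilingual-cased")
#     chunk_model = NerChunkModel(ner=ner, ner_parser=ner_parser)
#     texts, nums, sent_offsets, sents = chunker(docs)
#     entity_substr, entity_offsets, entity_positions, tags, doc_sent_offsets, doc_sents, probas = \
#         chunk_model(texts, nums, sent_offsets, sents)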