Source code for deeppavlov.models.slotfill.slotfill

# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from fuzzywuzzy import process
from overrides import overrides

from deeppavlov.core.common.registry import register
from import download
from deeppavlov.core.common.log import get_logger
from deeppavlov.core.models.serializable import Serializable
from deeppavlov.core.models.component import Component

log = get_logger(__name__)

@register('dstc_slotfilling')
class DstcSlotFillingNetwork(Component, Serializable):
    """Slot filling for DSTC2 task with neural network.

    Groups BIO-tagged tokens of an utterance into entity chunks and maps every
    chunk to a normalized slot value by fuzzy string matching against the
    table of slot values read from ``self.load_path``.

    Args:
        threshold: acceptance level for a fuzzy match; it is multiplied by 100
            before being compared with the match score.
        **kwargs: forwarded unchanged to the parent constructor.
    """

    def __init__(self, threshold: float = 0.8, **kwargs):
        super().__init__(**kwargs)
        self.threshold = threshold
        # Check existence of file with slots, slot values, and corrupted
        # (misspelled) slot values; the table is fetched if it is missing.
        self.load()

    @overrides
    def __call__(self, tokens_batch, tags_batch, *args, **kwargs):
        """Predict slot values for a batch of tagged utterances.

        Args:
            tokens_batch: sequence of token sequences, one per utterance.
            tags_batch: sequence of tag sequences aligned with ``tokens_batch``.

        Returns:
            A list with one ``{slot: value}`` dict per utterance. Utterances
            whose token sequence is empty (falsy) get an empty dict.
        """
        # One independent dict per utterance: ``[{}] * n`` would put the very
        # same dict object in every position, so mutating one result would
        # leak into all the others.
        slots = [{} for _ in range(len(tokens_batch))]
        for i, tokens in enumerate(tokens_batch):
            if tokens:
                slots[i] = self.predict_slots(tokens, tags_batch[i])
        return slots

    def predict_slots(self, tokens, tags):
        """Extract entities from one utterance and normalize them to slot values.

        Args:
            tokens: tokens of the utterance.
            tags: BIO tags aligned with ``tokens``.

        Returns:
            Dict mapping slot name to the normalized value. Matches scoring
            below ``self.threshold * 100`` are discarded; if a slot occurs
            several times, the last accepted match wins.
        """
        entities, slots = self._chunk_finder(tokens, tags)
        slot_values = {}
        for entity, slot in zip(entities, slots):
            match, score = self.ner2slot(entity, slot)
            if score >= self.threshold * 100:
                slot_values[slot] = match
        return slot_values

    def ner2slot(self, input_entity, slot):
        """Given a named entity, return the normalized slot value and its score.

        Args:
            input_entity: entity text, or a list of tokens that is joined
                with single spaces.
            slot: slot name used to look up candidates in ``self._slot_vals``.

        Returns:
            Tuple ``(normalized_value, score)`` for the best fuzzy match.
        """
        if isinstance(input_entity, list):
            input_entity = ' '.join(input_entity)
        # Parallel lists: entities[k] is a surface variant and
        # normalized_slot_vals[k] is the canonical value it belongs to.
        entities = []
        normalized_slot_vals = []
        for entity_name, variants in self._slot_vals[slot].items():
            for entity in variants:
                entities.append(entity)
                normalized_slot_vals.append(entity_name)
        # The huge limit keeps every candidate in the result; the first item
        # is taken as the best match.
        best_match, score = process.extract(input_entity, entities, limit=2 ** 20)[0]
        return normalized_slot_vals[entities.index(best_match)], score

    @staticmethod
    def _chunk_finder(tokens, tags):
        """For a BIO labeled sequence of tags extract all named entities from tokens.

        Args:
            tokens: tokens of the utterance.
            tags: tags aligned with ``tokens`` (``B-x``, ``I-x`` or ``O``).

        Returns:
            Tuple ``(entities, slots)`` of equal length: each entity is the
            space-joined text of a chunk and the matching slot is the label
            of the token that preceded the point where the chunk was closed.
        """
        entities = []
        slots = []
        chunk_tokens = []
        prev_tag = ''

        def close_chunk():
            # Emit the currently open chunk, if any, under the label of the
            # previous token (``prev_tag`` is read at call time).
            if chunk_tokens:
                entities.append(' '.join(chunk_tokens))
                slots.append(prev_tag)
                chunk_tokens.clear()

        for token, tag in zip(tokens, tags):
            current_tag = tag.split('-')[-1].strip()
            current_prefix = tag.split('-')[0]
            if tag.startswith('B-'):
                # A new entity begins: finish the previous one first.
                close_chunk()
                chunk_tokens.append(token)
            elif current_prefix == 'I':
                if current_tag != prev_tag:
                    # Label differs from the previous token: the open chunk
                    # is closed and this token is not added to any chunk.
                    close_chunk()
                else:
                    chunk_tokens.append(token)
            elif current_prefix == 'O':
                close_chunk()
            prev_tag = current_tag
        # Flush a chunk that runs up to the end of the utterance.
        close_chunk()
        return entities, slots

    def _download_slot_vals(self):
        """Download the slot values file to ``self.save_path``."""
        # NOTE(review): the URL literal is empty in this copy of the file;
        # confirm the real address against the upstream repository.
        url = ''
        download(self.save_path, url)

    def save(self, *args, **kwargs):
        """Write the slot values table to ``self.save_path`` as JSON."""
        with open(self.save_path, 'w', encoding='utf8') as f:
            json.dump(self._slot_vals, f)

    def load(self, *args, **kwargs):
        """Read the slot values table from ``self.load_path``, downloading it first if absent."""
        # NOTE(review): the download targets ``save_path`` while the read uses
        # ``load_path`` - presumably both point to the same file; confirm.
        if not self.load_path.exists():
            self._download_slot_vals()
        with open(self.load_path, encoding='utf8') as f:
            self._slot_vals = json.load(f)