Source code for

# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import Counter, defaultdict
from itertools import chain
from logging import getLogger
from typing import Iterable, Optional, Tuple

import numpy as np

from deeppavlov.core.common.errors import ConfigError
from deeppavlov.core.common.registry import register
from import zero_pad, is_str_batch, flatten_str_batch
from deeppavlov.core.models.estimator import Estimator

log = getLogger(__name__)

@register('simple_vocab')
class SimpleVocabulary(Estimator):
    """Implements simple vocabulary.

    Maps string tokens to consecutive integer indices and back. When built with
    :meth:`fit`, special tokens get the first indices (in the order given by
    ``special_tokens``), followed by the counted tokens in order of decreasing
    frequency. The vocabulary can be saved to / loaded from a text file with one
    ``token<TAB>count`` entry per line.

    Parameters:
        special_tokens: tuple of tokens that shouldn't be counted; they are always
            included in the vocabulary.
        max_tokens: upper bound on the number of most frequent counted tokens taken
            into the vocabulary. Special tokens are added on top of this bound.
        min_freq: minimal count of a token (except special tokens).
        pad_with_zeros: if True, then the top-level result of :meth:`__call__` is
            passed through ``zero_pad`` (padded with zeros up to the length of the
            longest element in batch) unless it is a batch of strings.
        unk_token: label assigned to unknown tokens. Looking up an unknown string
            returns the position of ``unk_token`` in ``special_tokens``; if
            ``unk_token`` is not a special token, unknown strings map to index 0.
        freq_drop_load: if True, then frequencies of tokens are set to min_freq on
            the model load and only the first whitespace-separated field of each
            line is used as the token.
        *args: ignored.
        **kwargs: forwarded to ``Estimator.__init__`` (presumably ``save_path`` /
            ``load_path`` — confirm against the base class).
    """

    def __init__(self,
                 special_tokens: Tuple[str, ...] = tuple(),
                 max_tokens: int = 2 ** 30,
                 min_freq: int = 0,
                 pad_with_zeros: bool = False,
                 unk_token: Optional[str] = None,
                 freq_drop_load: Optional[bool] = None,
                 *args,
                 **kwargs):
        super().__init__(**kwargs)
        self.special_tokens = special_tokens
        self._max_tokens = max_tokens
        self._min_freq = min_freq
        self._pad_with_zeros = pad_with_zeros
        self.unk_token = unk_token
        self.freq_drop_load = freq_drop_load
        self.reset()
        if self.load_path:
            self.load()

    def fit(self, *args):
        """Rebuild the vocabulary from scratch by counting tokens.

        Args:
            *args: batches of (possibly nested) string tokens; they are chained,
                flattened with ``flatten_str_batch`` and counted together.
        """
        self.reset()
        tokens = chain(*args)
        # filter(None, <>) -- to filter empty tokens
        self.freqs = Counter(filter(None, flatten_str_batch(tokens)))
        for special_token in self.special_tokens:
            self._t2i[special_token] = self.count
            self._i2t.append(special_token)
            self.count += 1
        # Special tokens found among the most frequent ones are skipped here (they
        # are already added above), so they do not consume extra indices.
        for token, freq in self.freqs.most_common()[:self._max_tokens]:
            if token in self.special_tokens:
                continue
            if freq >= self._min_freq:
                self._t2i[token] = self.count
                self._i2t.append(token)
                self.count += 1

    def _add_tokens_with_freqs(self, tokens, freqs):
        """Append tokens (in the given order) whose frequency passes ``min_freq``.

        Special tokens are kept regardless of their frequency. ``self.freqs`` is
        replaced by the given token/frequency pairs.
        """
        self.freqs = Counter()
        self.freqs.update(dict(zip(tokens, freqs)))
        for token, freq in zip(tokens, freqs):
            if freq >= self._min_freq or token in self.special_tokens:
                self._t2i[token] = self.count
                self._i2t.append(token)
                self.count += 1

    def __call__(self, batch, is_top=True, **kwargs):
        """Convert tokens to indices (or indices to tokens), recursing into iterables.

        Strings are treated as single tokens, not as iterables. An iterable whose
        elements are all ``None`` is returned unchanged.

        Args:
            batch: a token/index or an arbitrarily nested iterable of them.
            is_top: True for the outermost call; padding is applied only there.
            **kwargs: ignored.

        Returns:
            The looked-up value(s) with the same nesting as ``batch``; when
            ``pad_with_zeros`` is set, the top-level non-string result is passed
            through ``zero_pad``.
        """
        if isinstance(batch, Iterable) and not isinstance(batch, str):
            if all(k is None for k in batch):
                return batch
            looked_up_batch = [self(sample, is_top=False) for sample in batch]
        else:
            return self[batch]
        if self._pad_with_zeros and is_top and not is_str_batch(looked_up_batch):
            looked_up_batch = zero_pad(looked_up_batch)
        return looked_up_batch

    def save(self):
        """Write the vocabulary to ``save_path``, one ``token<TAB>count`` line per token, in index order."""
        log.info("[saving vocabulary to {}]".format(self.save_path))
        with self.save_path.open('wt', encoding='utf8') as f:
            for token in self._i2t:
                cnt = self.freqs[token]
                f.write('{}\t{:d}\n'.format(token, cnt))

    def load(self):
        """Reset the vocabulary and fill it from ``load_path`` if that file exists.

        A missing file is tolerated (the vocabulary stays empty) as long as its
        parent directory exists.

        Raises:
            ConfigError: if the parent directory of ``load_path`` does not exist,
                or if ``load_path`` is not set.
        """
        self.reset()
        if self.load_path:
            if self.load_path.is_file():
                log.debug("[loading vocabulary from {}]".format(self.load_path))
                tokens, counts = [], []
                # `with` guarantees the file handle is closed (it used to leak).
                with self.load_path.open('r', encoding='utf8') as f:
                    for ln in f:
                        # Blank lines can never be valid entries; skip them
                        # instead of failing on unpacking.
                        if not ln.strip():
                            continue
                        token, cnt = self.load_line(ln)
                        tokens.append(token)
                        counts.append(int(cnt))
                self._add_tokens_with_freqs(tokens, counts)
            elif not self.load_path.parent.is_dir():
                raise ConfigError("Provided `load_path` for {} doesn't exist!".format(
                    self.__class__.__name__))
        else:
            raise ConfigError("`load_path` for {} is not provided!".format(self))

    def load_line(self, ln):
        """Parse one vocabulary file line into ``(token, count)``.

        With ``freq_drop_load`` the token is the first whitespace-separated field
        and the count is ``min_freq``; otherwise the line is split on its last tab
        and the count is returned as an unparsed string.
        """
        if self.freq_drop_load:
            token = ln.strip().split()[0]
            cnt = self._min_freq
        else:
            token, cnt = ln.rsplit('\t', 1)
        return token, cnt

    @property
    def len(self):
        """Number of tokens in the vocabulary (same as ``len(self)``)."""
        return len(self)

    def keys(self):
        """Return a generator over tokens in index order."""
        return (self[n] for n in range(self.len))

    def values(self):
        """Return the list of all indices."""
        return list(range(self.len))

    def items(self):
        """Return an iterator of ``(token, index)`` pairs."""
        return zip(self.keys(), self.values())

    def __getitem__(self, key):
        """Look up a token by index or an index by token.

        Unknown string tokens map to the unknown-token index (see ``reset``).

        Raises:
            NotImplementedError: for keys that are neither int nor str.
        """
        if isinstance(key, (int, np.integer)):
            return self._i2t[key]
        elif isinstance(key, str):
            # Use .get() rather than indexing: indexing the defaultdict would
            # insert every unknown token into _t2i, growing it without bound and
            # making `token in vocab` true for tokens that were merely looked up.
            return self._t2i.get(key, self._t2i.default_factory())
        else:
            raise NotImplementedError("not implemented for type `{}`".format(type(key)))

    def __contains__(self, item):
        return item in self._t2i

    def __len__(self):
        return len(self._i2t)

    def reset(self):
        """Clear the vocabulary and frequencies.

        Unknown tokens map to the position of ``unk_token`` within
        ``special_tokens`` (its vocabulary index when special tokens were added
        first, as ``fit`` does), or to 0 if ``unk_token`` is not a special token.
        """
        self.freqs = None
        unk_index = 0
        if self.unk_token in self.special_tokens:
            unk_index = self.special_tokens.index(self.unk_token)
        self._t2i = defaultdict(lambda: unk_index)
        self._i2t = []
        self.count = 0

    def idxs2toks(self, idxs):
        """Convert a flat sequence of indices to the list of their tokens."""
        return [self[idx] for idx in idxs]