# Source code for deeppavlov.models.embedders.glove_embedder

# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
from logging import getLogger
from typing import Iterator

import numpy as np
from gensim.models import KeyedVectors
from overrides import overrides

from deeppavlov.core.common.registry import register
from deeppavlov.models.embedders.abstract_embedder import Embedder

log = getLogger(__name__)


@register('glove')
class GloVeEmbedder(Embedder):
    """
    Class implements GloVe embedding model

    Args:
        load_path: path where to load pre-trained embedding model from
        pad_zero: whether to pad samples or not

    Attributes:
        model: GloVe model instance
        tok2emb: dictionary with already embedded tokens
        dim: dimension of embeddings
        pad_zero: whether to pad sequence of tokens with zeros or not
        load_path: path with pre-trained GloVe model
    """

    def _get_word_vector(self, w: str) -> np.ndarray:
        """Return the embedding vector of token ``w``.

        Delegates to ``self.model[w]`` (gensim ``KeyedVectors`` lookup), which
        raises ``KeyError`` for tokens missing from the vocabulary.
        """
        return self.model[w]

    def load(self) -> None:
        """
        Load dict of embeddings from given file

        Reads ``self.load_path`` with ``KeyedVectors.load_word2vec_format`` and
        sets ``self.model`` and ``self.dim`` (the model's ``vector_size``).
        If the path does not exist, a warning is logged and neither
        ``self.model`` nor ``self.dim`` is updated.
        """
        log.info(f"[loading GloVe embeddings from `{self.load_path}`]")
        if not self.load_path.exists():
            # Best-effort: warn and return instead of raising.
            # NOTE(review): presumably so the component can be built before
            # the embeddings are downloaded — confirm intent.
            log.warning(f'{self.load_path} does not exist, cannot load embeddings from it!')
            return
        # NOTE(review): load_word2vec_format (text mode) expects the word2vec
        # layout with a "<vocab_size> <dim>" header line; raw Stanford GloVe
        # files have no header and need converting first — confirm the files
        # used here are already converted.
        self.model = KeyedVectors.load_word2vec_format(str(self.load_path))
        self.dim = self.model.vector_size

    @overrides
    def __iter__(self) -> Iterator[str]:
        """
        Iterate over all words from GloVe model vocabulary

        Supports both gensim < 4.0 (``KeyedVectors.vocab``) and gensim >= 4.0,
        where ``vocab`` was removed in favour of ``key_to_index``.

        Returns:
            iterator
        """
        # gensim 4.0 turned ``KeyedVectors.vocab`` into a property that raises
        # AttributeError; use its replacement when it exists.
        words = getattr(self.model, 'key_to_index', None)
        if words is None:
            words = self.model.vocab
        yield from words

    def serialize(self) -> bytes:
        """Pickle the loaded model into bytes.

        Protocol 4 is used because it supports very large objects, which
        embedding matrices can be.
        """
        return pickle.dumps(self.model, protocol=4)

    def deserialize(self, data: bytes) -> None:
        """Restore ``self.model`` from pickled ``data`` and refresh ``self.dim``.

        Security: ``pickle.loads`` can execute arbitrary code. Only pass
        trusted bytes, such as the output of :meth:`serialize`.
        """
        self.model = pickle.loads(data)
        self.dim = self.model.vector_size