from abc import ABCMeta, abstractmethod
import numpy as np
from deeppavlov.core.commands.utils import expand_path
from keras.preprocessing.sequence import pad_sequences
from deeppavlov.core.common.log import get_logger
log = get_logger(__name__)
class RankingDict(metaclass=ABCMeta):
"""Class to encode characters, tokens, whole contexts and responses with vocabularies, to pad and truncate.
Args:
save_path: A path, including filename, where the instance of
:class:`deeppavlov.models.ranking.ranking_network.RankingNetwork` is stored;
vocabulary and embedding files are saved to the same directory.
load_path: A path, including filename, from which the instance of
:class:`deeppavlov.models.ranking.ranking_network.RankingNetwork` is loaded;
vocabulary and embedding files are read from the same directory.
max_sequence_length: A maximum length of a sequence in tokens.
Longer sequences will be truncated and shorter ones will be padded.
tok_dynamic_batch: Whether to use dynamic batching. If ``True``, the maximum length of a sequence for a batch
will be equal to the maximum of all sequence lengths in this batch,
but not higher than ``max_sequence_length``.
padding: Padding. Possible values are ``pre`` and ``post``.
If set to ``pre`` a sequence will be padded at the beginning.
If set to ``post`` it will be padded at the end.
truncating: Truncating. Possible values are ``pre`` and ``post``.
If set to ``pre`` a sequence will be truncated at the beginning.
If set to ``post`` it will be truncated at the end.
token_embeddings: Whether to use token-level embeddings.
If ``True``, sequences of token indices are built.
char_embeddings: Whether to use character-level embeddings.
If ``True``, sequences of character indices are built for each token.
max_token_length: A maximum length of a token for representing it by a character-level embedding.
char_dynamic_batch: Whether to use dynamic batching for character-level embeddings.
If ``True``, the maximum length of a token for a batch
will be equal to the maximum of all token lengths in this batch,
but not higher than ``max_token_length``.
char_pad: Character-level padding. Possible values are ``pre`` and ``post``.
If set to ``pre`` a token will be padded at the beginning.
If set to ``post`` it will be padded at the end.
char_trunc: Character-level truncating. Possible values are ``pre`` and ``post``.
If set to ``pre`` a token will be truncated at the beginning.
If set to ``post`` it will be truncated at the end.
update_embeddings: Whether to store and update context and response embeddings or not.
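Example:
A minimal usage sketch. ``ToyRankingDict`` below is a hypothetical concrete
subclass implementing the abstract ``build_*`` methods; it is not part of the
library, and the paths are placeholders::

    vocab = ToyRankingDict(save_path='./model/rank_model',
                           load_path='./model/rank_model',
                           max_sequence_length=5,
                           max_token_length=10)
    vocab.init_from_scratch()
    # with the default padding='post' and truncating='post', token index
    # sequences are right-padded with zeros up to max_sequence_length
    batch = vocab.make_ints([['how', 'are', 'you']])  # array of shape (1, 5)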
"""
def __init__(self,
save_path: str,
load_path: str,
max_sequence_length: int,
max_token_length: int,
padding: str = 'post',
truncating: str = 'post',
token_embeddings: bool = True,
char_embeddings: bool = False,
char_pad: str = 'post',
char_trunc: str = 'post',
tok_dynamic_batch: bool = False,
char_dynamic_batch: bool = False,
update_embeddings: bool = False):
self.max_sequence_length = max_sequence_length
self.token_embeddings = token_embeddings
self.char_embeddings = char_embeddings
self.max_token_length = max_token_length
self.padding = padding
self.truncating = truncating
self.char_pad = char_pad
self.char_trunc = char_trunc
self.tok_dynamic_batch = tok_dynamic_batch
self.char_dynamic_batch = char_dynamic_batch
self.upd_embs = update_embeddings
save_path = expand_path(save_path).resolve().parent
load_path = expand_path(load_path).resolve().parent
self.char_save_path = save_path / "char2int.dict"
self.char_load_path = load_path / "char2int.dict"
self.tok_save_path = save_path / "tok2int.dict"
self.tok_load_path = load_path / "tok2int.dict"
self.cont_save_path = save_path / "cont2toks.dict"
self.cont_load_path = load_path / "cont2toks.dict"
self.resp_save_path = save_path / "resp2toks.dict"
self.resp_load_path = load_path / "resp2toks.dict"
self.cemb_save_path = str(save_path / "context_embs.npy")
self.cemb_load_path = str(load_path / "context_embs.npy")
self.remb_save_path = str(save_path / "response_embs.npy")
self.remb_load_path = str(load_path / "response_embs.npy")
self.int2tok_vocab = {}
self.tok2int_vocab = {}
self.response2toks_vocab = {}
self.response2emb_vocab = {}
self.context2toks_vocab = {}
self.context2emb_vocab = {}
def init_from_scratch(self):
log.info("[initializing new `{}`]".format(self.__class__.__name__))
if self.char_embeddings:
self.build_int2char_vocab()
self.build_char2int_vocab()
self.build_int2tok_vocab()
self.build_tok2int_vocab()
self.build_context2toks_vocabulary()
self.build_response2toks_vocabulary()
if self.upd_embs:
self.build_context2emb_vocabulary()
self.build_response2emb_vocabulary()
def load(self):
log.info("[initializing `{}` from saved]".format(self.__class__.__name__))
if self.char_embeddings:
self.load_int2char()
self.build_char2int_vocab()
self.load_int2tok()
self.build_tok2int_vocab()
self.load_context2toks()
self.load_response2toks()
if self.upd_embs:
self.load_cont()
self.load_resp()
def save(self):
log.info("[saving `{}`]".format(self.__class__.__name__))
if self.char_embeddings:
self.save_int2char()
self.save_int2tok()
self.save_context2toks()
self.save_response2toks()
if self.upd_embs:
self.save_cont()
self.save_resp()
@abstractmethod
def build_int2char_vocab(self):
pass
@abstractmethod
def build_int2tok_vocab(self):
pass
@abstractmethod
def build_response2toks_vocabulary(self):
pass
@abstractmethod
def build_context2toks_vocabulary(self):
pass
def build_char2int_vocab(self):
self.char2int_vocab = {char: ind for ind, char in self.int2char_vocab.items()}
def build_tok2int_vocab(self):
self.tok2int_vocab = {tok: ind for ind, tok in self.int2tok_vocab.items()}
def build_response2emb_vocabulary(self):
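# one placeholder entry per known response; the embedding vectors themselves are
# expected to be filled in later (when ``update_embeddings`` is enabled) and are
# persisted by ``save_resp``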
for i in range(len(self.response2toks_vocab)):
self.response2emb_vocab[i] = None
def build_context2emb_vocabulary(self):
for i in range(len(self.context2toks_vocab)):
self.context2emb_vocab[i] = None
def conts2toks(self, conts_li):
toks_li = [self.context2toks_vocab[cont] for cont in conts_li]
return toks_li
def resps2toks(self, resps_li):
toks_li = [self.response2toks_vocab[resp] for resp in resps_li]
return toks_li
def make_toks(self, items_li, type):
if type == "context":
toks_li = self.conts2toks(items_li)
elif type == "response":
toks_li = self.resps2toks(items_li)
else:
raise ValueError("Unknown type: {}. Expected 'context' or 'response'.".format(type))
return toks_li
def make_ints(self, toks_li):
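# msl/mtl: target sequence length (in tokens) and token length (in characters)
# for this batch; with dynamic batching they shrink to the batch maximum,
# otherwise the global max_sequence_length / max_token_length are used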
if self.tok_dynamic_batch:
msl = min(max([len(el) for el in toks_li]), self.max_sequence_length)
else:
msl = self.max_sequence_length
if self.char_dynamic_batch:
mtl = min(max(len(x) for el in toks_li for x in el), self.max_token_length)
else:
mtl = self.max_token_length
if self.token_embeddings and not self.char_embeddings:
return self.make_tok_ints(toks_li, msl)
elif not self.token_embeddings and self.char_embeddings:
return self.make_char_ints(toks_li, msl, mtl)
elif self.token_embeddings and self.char_embeddings:
tok_ints = self.make_tok_ints(toks_li, msl)
char_ints = self.make_char_ints(toks_li, msl, mtl)
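# token indices of shape (batch, msl) are expanded to (batch, msl, 1) and concatenated
# with character indices of shape (batch, msl, mtl), giving a (batch, msl, mtl + 1) tensor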
return np.concatenate([np.expand_dims(tok_ints, axis=2), char_ints], axis=2)
def make_tok_ints(self, toks_li, msl):
ints_li = []
for toks in toks_li:
ints = []
for tok in toks:
# out-of-vocabulary tokens are mapped to index 0
ints.append(self.tok2int_vocab.get(tok, 0))
ints_li.append(ints)
ints_li = pad_sequences(ints_li,
maxlen=msl,
padding=self.padding,
truncating=self.truncating)
return ints_li
def make_char_ints(self, toks_li, msl, mtl):
ints_li = np.zeros((len(toks_li), msl, mtl))
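# output shape: (batch, msl, mtl); k below is the sequence position a token is
# written to, shifted towards the end of the sequence axis when padding == 'pre'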
for i, toks in enumerate(toks_li):
if self.truncating == 'post':
toks = toks[:msl]
else:
toks = toks[-msl:]
for j, tok in enumerate(toks):
if self.padding == 'post':
k = j
else:
k = j + msl - len(toks)
ints = []
for char in tok:
index = self.char2int_vocab.get(char)
if index is not None:
ints.append(index)
else:
ints.append(0)
if self.char_trunc == 'post':
ints = ints[:mtl]
else:
ints = ints[-mtl:]
if self.char_pad == 'post':
ints_li[i, k, :len(ints)] = ints
else:
ints_li[i, k, -len(ints):] = ints
return ints_li
def save_int2char(self):
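# plain-text format: one vocabulary entry per line, "index<TAB>character"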
with self.char_save_path.open('w') as f:
f.write('\n'.join(['\t'.join([str(el[0]), el[1]]) for el in self.int2char_vocab.items()]))
def load_int2char(self):
with self.char_load_path.open('r') as f:
data = f.readlines()
# rstrip instead of slicing off the last character: the final line of the saved file has no trailing newline
self.int2char_vocab = {int(el.split('\t')[0]): el.split('\t')[1].rstrip('\n') for el in data}
def save_int2tok(self):
with self.tok_save_path.open('w') as f:
f.write('\n'.join(['\t'.join([str(el[0]), el[1]]) for el in self.int2tok_vocab.items()]))
def load_int2tok(self):
with self.tok_load_path.open('r') as f:
data = f.readlines()
self.int2tok_vocab = {int(el.split('\t')[0]): el.split('\t')[1].rstrip('\n') for el in data}
def save_context2toks(self):
with self.cont_save_path.open('w') as f:
f.write('\n'.join(['\t'.join([str(el[0]), ' '.join(el[1])]) for el in self.context2toks_vocab.items()]))
def load_context2toks(self):
with self.cont_load_path.open('r') as f:
data = f.readlines()
self.context2toks_vocab = {int(el.split('\t')[0]): el.split('\t')[1].rstrip('\n').split(' ') for el in data}
def save_response2toks(self):
with self.resp_save_path.open('w') as f:
f.write(
'\n'.join(['\t'.join([str(el[0]), ' '.join(el[1])]) for el in self.response2toks_vocab.items()]))
def load_response2toks(self):
with self.resp_load_path.open('r') as f:
data = f.readlines()
self.response2toks_vocab = {int(el.split('\t')[0]): el.split('\t')[1].rstrip('\n').split(' ') for el in data}
def save_cont(self):
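# stack context embedding vectors in index order into a single (n_contexts, emb_dim) array and save it as .npy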
context_embeddings = []
for i in range(len(self.context2emb_vocab)):
context_embeddings.append(self.context2emb_vocab[i])
context_embeddings = np.vstack(context_embeddings)
np.save(self.cemb_save_path, context_embeddings)
def load_cont(self):
context_embeddings_arr = np.load(self.cemb_load_path)
for i in range(context_embeddings_arr.shape[0]):
self.context2emb_vocab[i] = context_embeddings_arr[i]
def save_resp(self):
response_embeddings = []
for i in range(len(self.response2emb_vocab)):
response_embeddings.append(self.response2emb_vocab[i])
response_embeddings = np.vstack(response_embeddings)
np.save(self.remb_save_path, response_embeddings)
def load_resp(self):
response_embeddings_arr = np.load(self.remb_load_path)
for i in range(response_embeddings_arr.shape[0]):
self.response2emb_vocab[i] = response_embeddings_arr[i]