# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
from logging import getLogger
from typing import Union, Tuple, List, Optional
from deeppavlov.core.common.errors import ConfigError
from deeppavlov.core.models.component import Component
from deeppavlov.core.models.nn_model import NNModel
from deeppavlov.core.models.serializable import Serializable
log = getLogger(__name__)
[docs]class Chainer(Component):
"""
Builds an agent/component pipeline from heterogeneous components (Rule-based/ML/DL). It allows to train
and infer models in a pipeline as a whole.
Attributes:
pipe: list of components and their input and output variable names for inference
train_pipe: list of components and their input and output variable names for training and evaluation
in_x: names of inputs for pipeline inference mode
out_params: names of pipeline inference outputs
in_y: names of additional inputs for pipeline training and evaluation modes
forward_map: list of all variables in chainer's memory after running every component in ``self.pipe``
train_map: list of all variables in chainer's memory after running every component in ``train_pipe.pipe``
main: reference to the main component
Args:
in_x: names of inputs for pipeline inference mode
out_params: names of pipeline inference outputs
in_y: names of additional inputs for pipeline training and evaluation modes
"""
def __init__(self, in_x: Union[str, list] = None, out_params: Union[str, list] = None,
in_y: Union[str, list] = None, *args, **kwargs) -> None:
self.pipe: List[Tuple[Tuple[List[str], List[str]], List[str], Component]] = []
self.train_pipe = []
if isinstance(in_x, str):
in_x = [in_x]
if isinstance(in_y, str):
in_y = [in_y]
if isinstance(out_params, str):
out_params = [out_params]
self.in_x = in_x or ['x']
self.in_y = in_y or ['y']
self.out_params = out_params or self.in_x
self.forward_map = set(self.in_x)
self.train_map = self.forward_map.union(self.in_y)
self.main = None
def append(self, component: Component, in_x: [str, list, dict]=None, out_params: [str, list]=None,
in_y: [str, list, dict]=None, main=False):
if isinstance(in_x, str):
in_x = [in_x]
if isinstance(in_y, str):
in_y = [in_y]
if isinstance(out_params, str):
out_params = [out_params]
in_x = in_x or self.in_x
if isinstance(in_x, dict):
x_keys, in_x = zip(*in_x.items())
else:
x_keys = []
out_params = out_params or in_x
if in_y is not None:
if isinstance(in_y, dict):
y_keys, in_y = zip(*in_y.items())
else:
y_keys = []
keys = x_keys + y_keys
if bool(x_keys) != bool(y_keys):
raise ConfigError('`in` and `in_y` for a component have to both be lists or dicts')
component: NNModel
main = True
assert self.train_map.issuperset(in_x+in_y), ('Arguments {} are expected but only {} are set'
.format(in_x+in_y, self.train_map))
preprocessor = Chainer(self.in_x, in_x+in_y, self.in_y)
for (t_in_x_keys, t_in_x), t_out, t_component in self.train_pipe:
if t_in_x_keys:
t_in_x = dict(zip(t_in_x_keys, t_in_x))
preprocessor.append(t_component, t_in_x, t_out)
def train_on_batch(*args, **kwargs):
preprocessed = preprocessor.compute(*args, **kwargs)
if len(in_x+in_y) == 1:
preprocessed = [preprocessed]
if keys:
return component.train_on_batch(**dict(zip(keys, preprocessed)))
else:
return component.train_on_batch(*preprocessed)
self.train_on_batch = train_on_batch
self.process_event = component.process_event
if main:
self.main = component
if self.forward_map.issuperset(in_x):
self.pipe.append(((x_keys, in_x), out_params, component))
self.forward_map = self.forward_map.union(out_params)
if self.train_map.issuperset(in_x):
self.train_pipe.append(((x_keys, in_x), out_params, component))
self.train_map = self.train_map.union(out_params)
else:
raise ConfigError('Arguments {} are expected but only {} are set'.format(in_x, self.train_map))
def compute(self, x, y=None, targets=None):
if targets is None:
targets = self.out_params
in_params = list(self.in_x)
if len(in_params) == 1:
args = [x]
else:
args = list(zip(*x))
if y is None:
pipe = self.pipe
else:
pipe = self.train_pipe
if len(self.in_y) == 1:
args.append(y)
else:
args += list(zip(*y))
in_params += self.in_y
return self._compute(*args, pipe=pipe, param_names=in_params, targets=targets)
def __call__(self, *args):
return self._compute(*args, param_names=self.in_x, pipe=self.pipe, targets=self.out_params)
@staticmethod
def _compute(*args, param_names, pipe, targets):
expected = set(targets)
final_pipe = []
for (in_keys, in_params), out_params, component in reversed(pipe):
if expected.intersection(out_params):
expected = expected - set(out_params) | set(in_params)
final_pipe.append(((in_keys, in_params), out_params, component))
final_pipe.reverse()
if not expected.issubset(param_names):
raise RuntimeError(f'{expected} are required to compute {targets} but were not found in memory or inputs')
pipe = final_pipe
mem = dict(zip(param_names, args))
del args
for (in_keys, in_params), out_params, component in pipe:
x = [mem[k] for k in in_params]
if in_keys:
res = component(**dict(zip(in_keys, x)))
else:
res = component(*x)
if len(out_params) == 1:
mem[out_params[0]] = res
else:
mem.update(zip(out_params, res))
res = [mem[k] for k in targets]
if len(res) == 1:
res = res[0]
return res
def get_main_component(self) -> Optional[Serializable]:
try:
return self.main or self.pipe[-1][-1]
except IndexError:
log.warning('Cannot get a main component for an empty chainer')
return None
def save(self) -> None:
main_component = self.get_main_component()
if isinstance(main_component, Serializable):
main_component.save()
def load(self) -> None:
for in_params, out_params, component in self.train_pipe:
if callable(getattr(component, 'load', None)):
component.load()
def reset(self) -> None:
for in_params, out_params, component in self.train_pipe:
if callable(getattr(component, 'reset', None)):
component.reset()
def destroy(self):
if hasattr(self, 'train_pipe'):
for in_params, out_params, component in self.train_pipe:
if callable(getattr(component, 'destroy', None)):
component.destroy()
self.train_pipe.clear()
if hasattr(self, 'pipe'):
self.pipe.clear()
super().destroy()
def serialize(self) -> bytes:
data = []
for in_params, out_params, component in self.train_pipe:
data.append(component.serialize())
return pickle.dumps(data, protocol=4)
def deserialize(self, data: bytes) -> None:
data = pickle.loads(data)
for in_params, out_params, component in self.train_pipe:
component.deserialize(data)