# Source code for deeppavlov.core.common.chainer

# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pickle
from itertools import islice
from logging import getLogger
from types import FunctionType
from typing import Union, Tuple, List, Optional, Hashable, Reversible

from deeppavlov.core.common.errors import ConfigError
from deeppavlov.core.models.component import Component
from deeppavlov.core.models.nn_model import NNModel
from deeppavlov.core.models.serializable import Serializable

log = getLogger(__name__)


class Chainer(Component):
    """
    Builds a component pipeline from heterogeneous components (Rule-based/ML/DL). It allows to train
    and infer models in a pipeline as a whole.

    Attributes:
        pipe: list of components and their input and output variable names for inference
        train_pipe: list of components and their input and output variable names for training and evaluation
        in_x: names of inputs for pipeline inference mode
        out_params: names of pipeline inference outputs
        in_y: names of additional inputs for pipeline training and evaluation modes
        forward_map: set of all variables in chainer's memory after running every component in ``self.pipe``
        train_map: set of all variables in chainer's memory after running every component in ``self.train_pipe``
        main: reference to the main component

    Args:
        in_x: names of inputs for pipeline inference mode
        out_params: names of pipeline inference outputs
        in_y: names of additional inputs for pipeline training and evaluation modes
    """

    def __init__(self, in_x: Union[str, list] = None, out_params: Union[str, list] = None,
                 in_y: Union[str, list] = None, *args, **kwargs) -> None:
        # Every entry is ((input keyword names, input variable names), output variable names, component).
        self.pipe: List[Tuple[Tuple[List[str], List[str]], List[str], Component]] = []
        self.train_pipe = []

        # A single name may be given as a bare string; normalise it to a one-element list.
        if isinstance(in_x, str):
            in_x = [in_x]
        if isinstance(in_y, str):
            in_y = [in_y]
        if isinstance(out_params, str):
            out_params = [out_params]

        self.in_x = in_x or ['x']
        self.in_y = in_y or ['y']
        # Without explicit outputs the pipeline returns its own inputs.
        self.out_params = out_params or self.in_x

        self.forward_map = set(self.in_x)
        self.train_map = self.forward_map.union(self.in_y)

        # Not filled in by this class; it is only read here (``__getitem__`` and the repr helpers).
        self._components_dict = {}
        self.main = None

    def __getitem__(self, item):
        """Return a component by its position in ``self.train_pipe`` (``int``) or by its key in
        ``self._components_dict`` (anything else)."""
        if isinstance(item, int):
            _, _, component = self.train_pipe[item]
            return component
        return self._components_dict[item]

    def _ipython_key_completions_(self):
        """Return the keys of ``self._components_dict`` (used by IPython for ``chainer[<TAB>`` completion)."""
        return self._components_dict.keys()

    def _iter_labelled_components(self):
        """Yield ``(has_key, key, component)`` for every component of ``self.train_pipe``.

        ``key`` is the key under which the component is stored in ``self._components_dict``;
        ``has_key`` is ``False`` (and ``key`` is ``None``) when the component is not registered there.
        """
        hashable_keys = {v: k for k, v in self._components_dict.items() if isinstance(v, Hashable)}
        for _, _, component in self.train_pipe:
            if isinstance(component, Hashable) and component in hashable_keys:
                yield True, hashable_keys[component], component
                continue
            # Unhashable components cannot be dictionary keys, so they are found by identity instead.
            for k, v in self._components_dict.items():
                if v is component:
                    yield True, k, component
                    break
            else:
                yield False, None, component

    def __repr__(self):
        parts = [f'{key!r}: {component!r}' if has_key else repr(component)
                 for has_key, key, component in self._iter_labelled_components()]
        return f'Chainer[{", ".join(parts)}]'

    def _repr_pretty_(self, p, cycle):
        """method that defines ``Chainer``'s pretty printing rules for iPython

        Args:
            p (IPython.lib.pretty.RepresentationPrinter): pretty printer object
            cycle (bool): is ``True`` if pretty detected a cycle
        """
        if cycle:
            p.text('Chainer(...)')
            return
        with p.group(8, 'Chainer[', ']'):
            for i, (has_key, key, component) in enumerate(self._iter_labelled_components()):
                if i > 0:
                    p.text(',')
                    p.breakable()
                if has_key:
                    p.pretty(key)
                    p.text(': ')
                p.pretty(component)

    def append(self, component: Union[Component, FunctionType],
               in_x: Optional[Union[str, list, dict]] = None,
               out_params: Optional[Union[str, list]] = None,
               in_y: Optional[Union[str, list, dict]] = None,
               main: bool = False):
        """Add a component to the pipeline.

        Args:
            component: the component (or plain function) to add
            in_x: names of the variables passed to the component; a dict maps the component's
                keyword argument names to variable names. Defaults to ``self.in_x``.
            out_params: names under which the component's outputs are stored. Defaults to ``in_x``.
            in_y: names of additional variables passed to the component's ``train_on_batch``.
                When given, the component becomes the main one and the chainer's
                ``train_on_batch`` and ``process_event`` are bound to it.
            main: mark the component as the main one

        Raises:
            ConfigError: if exactly one of ``in_x``/``in_y`` is a dict, or if ``in_x`` refers to
                variables that are not available at this point of the training pipeline.
        """
        if isinstance(in_x, str):
            in_x = [in_x]
        if isinstance(in_y, str):
            in_y = [in_y]
        if isinstance(out_params, str):
            out_params = [out_params]
        in_x = in_x or self.in_x

        if isinstance(in_x, dict):
            x_keys, in_x = zip(*in_x.items())
        else:
            x_keys = []
        out_params = out_params or in_x

        if in_y is not None:
            if isinstance(in_y, dict):
                y_keys, in_y = zip(*in_y.items())
            else:
                y_keys = []

            # Validate before concatenating: with mixed kinds one side is a tuple and the other a
            # list, and ``+`` on them would raise a TypeError instead of this ConfigError.
            if bool(x_keys) != bool(y_keys):
                raise ConfigError('`in` and `in_y` for a component have to both be lists or dicts')
            keys = list(x_keys) + list(y_keys)
            train_inputs = list(in_x) + list(in_y)

            component: NNModel  # type hint only: a component with ``in_y`` is used as a trainable model
            main = True
            assert self.train_map.issuperset(train_inputs), ('Arguments {} are expected but only {} are set'
                                                             .format(train_inputs, self.train_map))

            # A sub-pipeline of everything appended so far; it computes this component's training inputs.
            preprocessor = Chainer(self.in_x, train_inputs, self.in_y)
            for (t_in_x_keys, t_in_x), t_out, t_component in self.train_pipe:
                if t_in_x_keys:
                    t_in_x = dict(zip(t_in_x_keys, t_in_x))
                preprocessor.append(t_component, t_in_x, t_out)

            def train_on_batch(*args, **kwargs):
                preprocessed = preprocessor.compute(*args, **kwargs)
                # ``compute`` unwraps a single target, so wrap it back for uniform handling.
                if len(train_inputs) == 1:
                    preprocessed = [preprocessed]
                if keys:
                    return component.train_on_batch(**dict(zip(keys, preprocessed)))
                return component.train_on_batch(*preprocessed)

            self.train_on_batch = train_on_batch
            self.process_event = component.process_event

        if main:
            self.main = component

        # The component joins the inference pipeline only if its inputs are available without ``in_y``.
        if self.forward_map.issuperset(in_x):
            self.pipe.append(((x_keys, in_x), out_params, component))
            self.forward_map = self.forward_map.union(out_params)

        if self.train_map.issuperset(in_x):
            self.train_pipe.append(((x_keys, in_x), out_params, component))
            self.train_map = self.train_map.union(out_params)
        else:
            raise ConfigError('Arguments {} are expected but only {} are set'.format(in_x, self.train_map))

    def compute(self, x, y=None, targets=None):
        """Run the pipeline and return the values of ``targets``.

        Args:
            x: pipeline inputs. Passed as is when ``self.in_x`` holds one name, otherwise
                transposed with ``zip(*x)`` so that each name receives its own sequence.
            y: additional inputs; when given, the training pipeline is used and ``y`` is handled
                like ``x`` with respect to ``self.in_y``.
            targets: names of the variables to return. Defaults to ``self.out_params``.

        Returns:
            the value of the single target, or a list of values when there are several targets.
        """
        if targets is None:
            targets = self.out_params
        in_params = list(self.in_x)
        if len(in_params) == 1:
            args = [x]
        else:
            args = list(zip(*x))

        if y is None:
            pipe = self.pipe
        else:
            pipe = self.train_pipe
            if len(self.in_y) == 1:
                args.append(y)
            else:
                args += list(zip(*y))
            in_params += self.in_y

        return self._compute(*args, pipe=pipe, param_names=in_params, targets=targets)

    def __call__(self, *args):
        """Run the inference pipeline on ``args`` (one positional argument per name in ``self.in_x``)."""
        return self._compute(*args, param_names=self.in_x, pipe=self.pipe, targets=self.out_params)

    @staticmethod
    def _compute(*args, param_names, pipe, targets):
        """Run only those components of ``pipe`` that are needed to produce ``targets``.

        Args:
            *args: input values, matched positionally with ``param_names``
            param_names: names of the input variables
            pipe: pipeline entries ``((in_keys, in_params), out_params, component)``
            targets: names of the variables to return

        Returns:
            the value of the single target, or a list of values when there are several targets.

        Raises:
            RuntimeError: if a variable required for ``targets`` is neither produced by ``pipe``
                nor listed in ``param_names``.
        """
        # Walk the pipe backwards and keep a component only if it produces something still needed.
        expected = set(targets)
        final_pipe = []
        for (in_keys, in_params), out_params, component in reversed(pipe):
            if expected.intersection(out_params):
                expected = expected - set(out_params) | set(in_params)
                final_pipe.append(((in_keys, in_params), out_params, component))
        final_pipe.reverse()
        if not expected.issubset(param_names):
            raise RuntimeError(f'{expected} are required to compute {targets} but were not found in memory or inputs')
        pipe = final_pipe

        mem = dict(zip(param_names, args))
        del args  # drop the extra reference; the values stay reachable through ``mem``

        for (in_keys, in_params), out_params, component in pipe:
            x = [mem[k] for k in in_params]
            if in_keys:
                res = component.__call__(**dict(zip(in_keys, x)))
            else:
                res = component.__call__(*x)
            if len(out_params) == 1:
                mem[out_params[0]] = res
            else:
                mem.update(zip(out_params, res))

        res = [mem[k] for k in targets]
        if len(res) == 1:
            res = res[0]
        return res

    def batched_call(self, *args: Reversible, batch_size: int = 16) -> Union[list, Tuple[list, ...]]:
        """
        Partitions data into mini-batches and applies :meth:`__call__` to each batch.

        Args:
            args: input data, each element of the data corresponds to a single model inputs sequence.
            batch_size: the size of a batch.

        Returns:
            the model output as if the data was passed to the :meth:`__call__` method.
        """
        args = [iter(arg) for arg in args]
        answer = [[] for _ in self.out_params]

        while True:
            batch = [list(islice(arg, batch_size)) for arg in args]
            if not any(batch):  # empty batch, reached the end
                break

            curr_answer = self.__call__(*batch)
            if len(self.out_params) == 1:
                curr_answer = [curr_answer]

            for y, curr_y in zip(answer, curr_answer):
                y.extend(curr_y)

        if len(self.out_params) == 1:
            answer = answer[0]
        return answer

    def get_main_component(self) -> Optional[Serializable]:
        """Return ``self.main`` if it is set, otherwise the last component of ``self.pipe``.

        Returns ``None`` (and logs a warning) when there is no main component and the pipe is empty.
        """
        # Compare with None explicitly so that a main component that is falsy is still returned.
        if self.main is not None:
            return self.main
        try:
            return self.pipe[-1][-1]
        except IndexError:
            log.warning('Cannot get a main component for an empty chainer')
            return None

    def save(self) -> None:
        """Save the main component if it is a ``Serializable``."""
        main_component = self.get_main_component()
        if isinstance(main_component, Serializable):
            main_component.save()

    def load(self) -> None:
        """Call ``load()`` on every component of ``self.train_pipe`` that has a callable ``load``."""
        for _, _, component in self.train_pipe:
            if callable(getattr(component, 'load', None)):
                component.load()

    def reset(self) -> None:
        """Call ``reset()`` on every component of ``self.train_pipe`` that has a callable ``reset``."""
        for _, _, component in self.train_pipe:
            if callable(getattr(component, 'reset', None)):
                component.reset()

    def destroy(self):
        """Call ``destroy()`` on every component that has it, empty both pipes and destroy the parent."""
        # The attributes are checked with ``hasattr`` so that a partially built chainer can be destroyed.
        if hasattr(self, 'train_pipe'):
            for _, _, component in self.train_pipe:
                if callable(getattr(component, 'destroy', None)):
                    component.destroy()
            self.train_pipe.clear()
        if hasattr(self, 'pipe'):
            self.pipe.clear()
        super().destroy()