# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from abc import abstractmethod
from copy import deepcopy
from logging import getLogger
from typing import Optional, List, Union
import numpy as np
import tensorflow as tf
from keras import backend as K
from overrides import overrides
from deeppavlov.core.models.nn_model import NNModel
from deeppavlov.core.models.tf_backend import TfModelMeta
from deeppavlov.core.models.lr_scheduled_model import LRScheduledModel
log = getLogger(__name__)
[docs]class KerasModel(NNModel, metaclass=TfModelMeta):
"""
Builds Keras model with TensorFlow backend.
Attributes:
epochs_done: number of epochs that were done
batches_seen: number of epochs that were seen
train_examples_seen: number of training samples that were seen
sess: tf session
"""
def __init__(self, **kwargs) -> None:
"""
Initialize model using keyword parameters
Args:
kwargs (dict): Dictionary with model parameters
"""
self.epochs_done = 0
self.batches_seen = 0
self.train_examples_seen = 0
super().__init__(save_path=kwargs.get("save_path"),
load_path=kwargs.get("load_path"),
mode=kwargs.get("mode"))
@staticmethod
def _config_session():
"""
Configure session for particular device
Returns:
tensorflow.Session
"""
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.visible_device_list = '0'
return tf.Session(config=config)
@abstractmethod
def load(self, *args, **kwargs) -> None:
pass
@abstractmethod
def save(self, *args, **kwargs) -> None:
pass
def process_event(self, event_name: str, data: dict) -> None:
"""
Process event after epoch
Args:
event_name: whether event is send after epoch or batch.
Set of values: ``"after_epoch", "after_batch"``
data: event data (dictionary)
Returns:
None
"""
if event_name == "after_epoch":
self.epochs_done = data["epochs_done"]
self.batches_seen = data["batches_seen"]
self.train_examples_seen = data["train_examples_seen"]
return
class LRScheduledKerasModel(LRScheduledModel, KerasModel):
"""
KerasModel enhanced with optimizer, learning rate and momentum
management and search.
"""
def __init__(self, **kwargs):
"""
Initialize model with given parameters
Args:
**kwargs: dictionary of parameters
"""
if isinstance(kwargs.get("learning_rate"), float) and isinstance(kwargs.get("learning_rate_decay"), float):
KerasModel.__init__(self, **kwargs)
else:
KerasModel.__init__(self, **kwargs)
LRScheduledModel.__init__(self, **kwargs)
@abstractmethod
def get_optimizer(self):
"""
Return instance of keras optimizer
Args:
None
"""
pass
@overrides
def _init_learning_rate_variable(self):
"""
Initialize learning rate
Returns:
None
"""
return None
@overrides
def _init_momentum_variable(self):
"""
Initialize momentum
Returns:
None
"""
return None
@overrides
def get_learning_rate_variable(self):
"""
Extract value of learning rate from optimizer
Returns:
learning rate value
"""
return self.get_optimizer().lr
@overrides
def get_momentum_variable(self):
"""
Extract values of momentum variables from optimizer
Returns:
optimizer's `rho` or `beta_1`
"""
optimizer = self.get_optimizer()
if hasattr(optimizer, 'rho'):
return optimizer.rho
elif hasattr(optimizer, 'beta_1'):
return optimizer.beta_1
return None
@overrides
def _update_graph_variables(self, learning_rate: float = None, momentum: float = None):
"""
Update graph variables setting giving `learning_rate` and `momentum`
Args:
learning_rate: learning rate value to be set in graph (set if not None)
momentum: momentum value to be set in graph (set if not None)
Returns:
None
"""
if learning_rate is not None:
K.set_value(self.get_learning_rate_variable(), learning_rate)
# log.info(f"Learning rate = {learning_rate}")
if momentum is not None:
K.set_value(self.get_momentum_variable(), momentum)
# log.info(f"Momentum = {momentum}")
def process_event(self, event_name: str, data: dict):
"""
Process event after epoch
Args:
event_name: whether event is send after epoch or batch.
Set of values: ``"after_epoch", "after_batch"``
data: event data (dictionary)
Returns:
None
"""
if (isinstance(self.opt.get("learning_rate", None), float) and
isinstance(self.opt.get("learning_rate_decay", None), float)):
pass
else:
if event_name == 'after_train_log':
if (self.get_learning_rate_variable() is not None) and ('learning_rate' not in data):
data['learning_rate'] = float(K.get_value(self.get_learning_rate_variable()))
# data['learning_rate'] = self._lr
if (self.get_momentum_variable() is not None) and ('momentum' not in data):
data['momentum'] = float(K.get_value(self.get_momentum_variable()))
# data['momentum'] = self._mom
else:
super().process_event(event_name, data)