Source code for deeppavlov.core.commands.infer

# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import Optional, Union

from deeppavlov.core.commands.utils import set_deeppavlov_root, import_packages
from deeppavlov.core.common.chainer import Chainer
from deeppavlov.core.common.file import read_json

from deeppavlov.core.agent.agent import Agent
from deeppavlov.core.common.params import from_params
from deeppavlov.core.common.log import get_logger

log = get_logger(__name__)

def build_model_from_config(config: Union[str, Path, dict], mode: str = 'infer',
                            load_trained: bool = False, as_component: bool = False) -> Chainer:
    """Build and return the model described in corresponding configuration file.

    Args:
        config: path to a configuration file (loaded with ``read_json``) or an
            already-loaded configuration dict.
        mode: mode forwarded to ``from_params`` for every pipe component.
        load_trained: if True, every pipe component whose config contains
            ``fit_on`` or ``in_y`` is built with ``load_path`` set to its
            ``save_path`` (a warning is logged when ``save_path`` is missing).
            The caller's config dicts are not modified.
        as_component: forwarded to the :class:`Chainer` constructor.

    Returns:
        A :class:`Chainer` containing every pipe component that declares ``in``.
    """
    if isinstance(config, (str, Path)):
        config = read_json(config)
    set_deeppavlov_root(config)

    import_packages(config.get('metadata', {}).get('imports', []))

    model_config = config['chainer']

    model = Chainer(model_config['in'], model_config['out'], model_config.get('in_y'),
                    as_component=as_component)

    for component_config in model_config['pipe']:
        if load_trained and ('fit_on' in component_config or 'in_y' in component_config):
            if 'save_path' in component_config:
                # Shallow copy: redirecting load_path must not leak into the caller's
                # config, which may be reused later for a fresh (untrained) build.
                component_config = {**component_config, 'load_path': component_config['save_path']}
            else:
                log.warning('No "save_path" parameter for the {} component, so "load_path" will not be renewed'
                            .format(component_config.get('name', component_config.get('ref', 'UNKNOWN'))))

        component = from_params(component_config, mode=mode)

        # Components without an 'in' key are still instantiated by from_params above,
        # but are not appended to the pipeline.
        if 'in' in component_config:
            c_in = component_config['in']
            c_out = component_config['out']
            in_y = component_config.get('in_y', None)
            main = component_config.get('main', False)
            model.append(component, c_in, c_out, in_y, main)

    return model
def build_agent_from_config(config_path: str) -> Agent:
    """Read an agent configuration file and construct the agent it describes.

    Args:
        config_path: path to the configuration file (loaded with ``read_json``);
            it must contain ``skills`` and ``commutator`` sections.

    Returns:
        An :class:`Agent` built from the ``skills`` and ``commutator`` sections.
    """
    agent_config = read_json(config_path)
    return Agent(agent_config['skills'], agent_config['commutator'])
def interact_agent(config_path: str) -> None:
    """Start interaction with the agent described in corresponding configuration file.

    Builds the agent, its commutator and one model per skill config, then reads
    utterances from stdin with a ``':: '`` prompt. Each utterance is sent to every
    skill model; the commutator picks the answer that gets printed, and the turn
    is appended to ``agent.history``. The session ends on ``exit``, ``stop``,
    ``quit`` or ``q``, or when input is exhausted (EOF).

    Args:
        config_path: path to the agent configuration file.
    """
    agent = build_agent_from_config(config_path)
    commutator = from_params(agent.commutator_config)

    models = [build_model_from_config(skill_config) for skill_config in agent.skill_configs]
    while True:
        # get input from user; EOF (Ctrl-D, closed pipe) ends the session like an exit command
        try:
            context = input(':: ')
        except EOFError:
            return

        # check for exit command
        if context in ('exit', 'stop', 'quit', 'q'):
            return

        # NOTE(review): build_model_from_config returns a Chainer, so every key below is
        # 'Chainer'; confirm the Chainer in use actually provides an `infer` method.
        predictions = [{model.__class__.__name__: model.infer(context)} for model in models]
        idx, name, pred = commutator.infer(predictions)
        print('>>', pred)

        agent.history.append({'context': context, "predictions": predictions,
                              "winner": {"idx": idx, "model": name, "prediction": pred}})
        log.debug("Current history: {}".format(agent.history))
def interact_model(config_path: str) -> None:
    """Start interaction with the model described in corresponding configuration file.

    For every input name in ``model.in_x`` the user is prompted with ``'<name>::'``.
    A single input is sent to the model as a one-element batch; several inputs are
    sent together as one sample. The result items are printed after ``'>>'``.
    The session ends when any prompt receives ``exit``, ``stop``, ``quit`` or ``q``,
    or when input is exhausted (EOF).

    Args:
        config_path: path to the model configuration file.
    """
    config = read_json(config_path)
    model = build_model_from_config(config)

    while True:
        args = []
        for in_x in model.in_x:
            # EOF (Ctrl-D, closed pipe) ends the session instead of raising a traceback
            try:
                arg = input('{}::'.format(in_x))
            except EOFError:
                return

            # check for exit command
            if arg in ('exit', 'stop', 'quit', 'q'):
                return
            args.append(arg)

        if len(args) == 1:
            pred = model(args)
        else:
            pred = model([args])

        print('>>', *pred)
def predict_on_stream(config_path: str, batch_size: int = 1, file_path: Optional[str] = None) -> None:
    """Make a prediction with the component described in corresponding configuration file.

    Input lines are read from ``file_path``, or from stdin when it is None or ``'-'``,
    and stripped. Every ``len(model.in_x)`` consecutive lines form one sample. Samples
    are fed to the model in batches of ``batch_size``. Each result is printed on its
    own line, flushed immediately:
    - strings are printed as-is;
    - numpy objects are converted with ``tolist()`` first;
    - other non-strings are JSON-encoded (``ensure_ascii=False``).

    Args:
        config_path: path to the model configuration file.
        batch_size: number of samples per model call.
        file_path: input file path; None or ``'-'`` means stdin.

    Raises:
        RuntimeError: if input would be read from stdin attached to a terminal.
    """
    import sys
    import json
    from itertools import islice

    if file_path is None or file_path == '-':
        if sys.stdin.isatty():
            raise RuntimeError('To process data from terminal please use interact mode')
        f = sys.stdin
    else:
        f = open(file_path, encoding='utf8')

    # try/finally so the input file is closed even if config loading,
    # model building or inference raises.
    try:
        config = read_json(config_path)
        model: Chainer = build_model_from_config(config)

        args_count = len(model.in_x)
        while True:
            lines = [line.strip() for line in islice(f, batch_size * args_count)]

            if args_count > 1:
                # zip over one shared iterator groups consecutive lines into samples;
                # it silently drops an incomplete trailing group, so report that.
                leftover = len(lines) % args_count
                if leftover:
                    log.warning('Ignoring {} trailing input line(s): every sample needs {} lines'
                                .format(leftover, args_count))
                batch = list(zip(*[iter(lines)] * args_count))
            else:
                batch = lines

            if not batch:
                break

            for res in model(batch):
                if type(res).__module__ == 'numpy':
                    res = res.tolist()
                if not isinstance(res, str):
                    res = json.dumps(res, ensure_ascii=False)
                print(res, flush=True)
    finally:
        if f is not sys.stdin:
            f.close()