# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sqlite3
from typing import List, Any, Dict, Optional, Union
from random import Random
from pathlib import Path
from overrides import overrides
from deeppavlov.core.common.log import get_logger
from deeppavlov.core.common.registry import register
from deeppavlov.core.data.data_fitting_iterator import DataFittingIterator
from deeppavlov.core.commands.utils import expand_path
# Module-level logger, created through the project's logging helper and
# shared by every method of the iterator defined in this module.
logger = get_logger(__name__)
@register('sqlite_iterator')
class SQLiteDataIterator(DataFittingIterator):
    """Iterate over documents stored in an SQLite database.

    On construction the iterator opens the database, picks the first table
    listed in ``sqlite_master``, reads all values of its ``id`` column and
    builds a mapping from every document id to its positional index.

    Args:
        load_path: a path to a local DB file
        batch_size: a number of samples in a single batch
        shuffle: whether to shuffle data during batching
        seed: random seed for data shuffling
        **kwargs: accepted for config compatibility; not used by this class

    Attributes:
        connect: a DB connection
        db_name: the name of the DB table the documents are read from
        doc_ids: DB document ids
        doc2index: a dictionary mapping document ids to integer indices
        batch_size: a number of samples in a single batch
        shuffle: whether to shuffle data during batching
        random: an instance of :class:`Random` class.

    """

    def __init__(self, load_path: Union[str, Path], batch_size: Optional[int] = None,
                 shuffle: Optional[bool] = None, seed: Optional[int] = None, **kwargs) -> None:
        load_path = str(expand_path(load_path))
        logger.info("Connecting to database, path: {}".format(load_path))

        try:
            self.connect = sqlite3.connect(load_path, check_same_thread=False)
        except sqlite3.OperationalError as e:
            e.args = e.args + ("Check that DB path exists and is a valid DB file",)
            raise

        try:
            try:
                self.db_name = self.get_db_name()
            except TypeError as e:
                # get_db_name() raises TypeError when the DB contains no tables.
                e.args = e.args + (
                    'Check that DB path was created correctly and is not empty. '
                    'Check that a correct dataset_format is passed to the ODQAReader config',)
                raise
            self.doc_ids = self.get_doc_ids()
        except Exception:
            # Construction failed after the connection was opened: release it
            # so the DB file handle is not leaked, then propagate the error.
            self.connect.close()
            raise

        self.doc2index = self.map_doc2idx()
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.random = Random(seed)

    @overrides
    def get_doc_ids(self) -> List[Any]:
        """Get document ids.

        Returns:
            the values of the ``id`` column for every row of the table

        """
        cursor = self.connect.cursor()
        try:
            cursor.execute('SELECT id FROM {}'.format(self.db_name))
            return [row[0] for row in cursor.fetchall()]
        finally:
            # Close the cursor even if the query fails.
            cursor.close()

    def get_db_name(self) -> str:
        """Get DB name.

        Only the first table listed in ``sqlite_master`` is used.

        Returns:
            the name of the first table found in the database

        Raises:
            TypeError: if the database does not contain any table

        """
        cursor = self.connect.cursor()
        try:
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            row = cursor.fetchone()
        finally:
            cursor.close()
        if row is None:
            # TypeError is kept as the failure type because __init__ catches
            # it to attach a troubleshooting hint.
            raise TypeError("No tables found in the SQLite database")
        return row[0]

    def map_doc2idx(self) -> Dict[Any, int]:
        """Map DB ids to integer ids.

        Returns:
            a dictionary of document ids and their positions in ``doc_ids``

        """
        doc2idx = {doc_id: i for i, doc_id in enumerate(self.doc_ids)}
        logger.info(
            "SQLite iterator: The size of the database is {} documents".format(len(doc2idx)))
        return doc2idx

    @overrides
    def get_doc_content(self, doc_id: Any) -> Optional[str]:
        """Get document content by id.

        Args:
            doc_id: a document id

        Returns:
            the ``text`` value of the matching row, or ``None`` if no row
            with this id exists

        """
        cursor = self.connect.cursor()
        try:
            # The id is bound as a query parameter; only the table name is
            # interpolated, because SQLite cannot bind identifiers.
            cursor.execute(
                "SELECT text FROM {} WHERE id = ?".format(self.db_name),
                (doc_id,)
            )
            row = cursor.fetchone()
        finally:
            cursor.close()
        return None if row is None else row[0]