# vim: ft=python fileencoding=utf-8 sts=4 sw=4 et:
from typing import TYPE_CHECKING, List, Optional, Tuple

import numpy as np
from tensorflow import keras

from .datatype import NumericArray

if TYPE_CHECKING:
    from .datatype import DataType
class Dataset(keras.utils.Sequence):
    """
    Batched view over a pair of (x, y) arrays, usable as a Keras data sequence.

    Rows are never accessed directly: every read goes through an internal
    index permutation, so shuffling only reorders that permutation and leaves
    the underlying arrays untouched. Each element read from the dataset is
    passed through the ``process`` method of the matching data type
    (``x_type`` / ``y_type``) before being returned.
    """

    def __init__(
        self,
        x_data: Optional["np.ndarray"] = None,
        y_data: Optional["np.ndarray"] = None,
        x_type: Optional["DataType"] = None,
        y_type: Optional["DataType"] = None,
        batch_size: int = 32,
        shuffled: bool = False,
    ):
        """
        Create a dataset.

        :param x_data: input samples, indexed along the first axis. Defaults
            to an empty array.
        :param y_data: targets, indexed along the first axis. Defaults to an
            empty array.
        :param x_type: data type used to preprocess each `x` element.
            Defaults to ``NumericArray()``.
        :param y_type: data type used to preprocess each `y` element.
            Defaults to ``NumericArray()``.
        :param batch_size: number of items per batch.
        :param shuffled: whether the items are visited in random order.
        """
        # Let the Keras base class set up its own state; this is a no-op on
        # Keras versions whose Sequence defines no __init__.
        super().__init__()

        if x_data is None:
            x_data = np.empty(0)
        if y_data is None:
            y_data = np.empty(0)
        if x_type is None:
            x_type = NumericArray()
        if y_type is None:
            y_type = NumericArray()

        # Data arrays
        self.__x, self.__y = x_data, y_data
        # Data types
        self.x_type, self.y_type = x_type, y_type
        # Class attributes
        # The permutation must exist before `shuffled` is assigned, because
        # the property setter shuffles it in place.
        self.__indexes = np.arange(self.__x.shape[0])
        self.shuffled = shuffled  # type: ignore
        self.batch_size = batch_size

    @property
    def shuffled(self) -> bool:
        """
        Check if the dataset is shuffled (dataset items randomly sorted)
        """
        return self.__shuffled

    @shuffled.setter  # type: ignore
    def shuffled(self, toggle: bool):
        """
        Enable or disable shuffling. Enabling it reshuffles the items;
        disabling it restores the original (sequential) order.
        """
        self.__shuffled = toggle
        if self.__shuffled:
            self.shuffle()
        else:
            self.__indexes = np.arange(self.__x.shape[0])

    def shuffle(self):
        """
        Randomly reorder the dataset items and mark the dataset as shuffled.
        """
        self.__shuffled = True
        np.random.shuffle(self.__indexes)

    def delete_rows(self, start: int, n: int = 1):
        """
        Delete `n` items starting at position `start` of the current
        (possibly shuffled) order.

        The range is clamped to the dataset bounds, as in `items`; an empty
        range leaves the dataset unchanged.
        """
        total = len(self.__indexes)
        first = max(start, 0)
        last = min(start + n, total)
        if last <= first:
            return

        # Rows of the underlying arrays that the selected positions point to.
        removed = self.__indexes[first:last]

        # axis=0 removes whole rows; without it np.delete flattens the array.
        self.__x = np.delete(self.__x, removed, axis=0)
        self.__y = np.delete(self.__y, removed, axis=0)

        remaining = np.delete(self.__indexes, slice(first, last))
        # The arrays shrank, so every surviving index must drop by the number
        # of removed rows that were located before it.
        self.__indexes = remaining - np.searchsorted(np.sort(removed), remaining)

    def head(self, n: int = 10) -> Tuple["np.ndarray", "np.ndarray"]:
        """
        Returns the first `n` items on the dataset.
        """
        return self.items(0, n)

    def items(self, start: int, end: int) -> Tuple["np.ndarray", "np.ndarray"]:
        """
        Return the elements between start and end as a tuple of (x, y) items
        Range is EXCLUSIVE [start, end) and is clamped to the dataset bounds.
        """
        start = max(start, 0)
        end = min(end, len(self.__indexes))
        indexes = self.__indexes[start:end]
        return self.__preprocess_data(self.__x[indexes], self.__y[indexes])

    def __len__(self) -> int:
        """
        Return the number of batches in the dataset (the last batch may be
        smaller than `batch_size`).
        """
        return int(np.ceil(len(self.__x) / float(self.batch_size)))

    def __getitem__(self, idx: int) -> Tuple["np.ndarray", "np.ndarray"]:
        """
        Return the batch number `idx` as a tuple of preprocessed (x, y) arrays.
        """
        batch_start = idx * self.batch_size
        batch_end = (idx + 1) * self.batch_size
        batch_indexes = self.__indexes[batch_start:batch_end]
        return self.__preprocess_data(
            self.__x[batch_indexes], self.__y[batch_indexes]
        )

    def __preprocess_data(
        self, x_data: "np.ndarray", y_data: "np.ndarray"
    ) -> Tuple["np.ndarray", "np.ndarray"]:
        """
        Preprocess the data. For example, if the image is a path to a file, load it and
        return the corresponding array.
        """
        x_data = np.array([self.x_type.process(element) for element in x_data])
        y_data = np.array([self.y_type.process(element) for element in y_data])
        return (x_data, y_data)