import tensorflow as tf
from tensorflow.keras.utils import Sequence
from sklearn.preprocessing import OneHotEncoder
import numpy as np

# This creates a pool of CPU resources.
# physical_devices = tf.config.experimental.list_physical_devices("CPU")
# tf.config.experimental.set_virtual_device_configuration(
#     physical_devices[0], [
#         tf.config.experimental.VirtualDeviceConfiguration(),
#         tf.config.experimental.VirtualDeviceConfiguration()
#     ])

import misc


def _round_up(size, multiple):
    """Return ``int(size)`` rounded up to the nearest multiple of ``multiple``."""
    size = int(size)
    remainder = size % multiple
    return size + (multiple - remainder) if remainder else size


def _dense_one_hot_encoder(cardinality):
    """Build a dense-output ``OneHotEncoder`` over the categories ``0..cardinality-1``.

    scikit-learn 1.2 renamed the ``sparse`` keyword to ``sparse_output`` and
    1.4 removed ``sparse`` entirely, so the new name is tried first and the
    old one is used only on releases that reject it with ``TypeError``.
    """
    kwargs = dict(handle_unknown='ignore', categories=[np.arange(cardinality)])
    try:
        return OneHotEncoder(sparse_output=False, **kwargs)
    except TypeError:
        return OneHotEncoder(sparse=False, **kwargs)


class BinaryOneHotGenerator(Sequence):
    """Keras ``Sequence`` of one-hot encoded random bit groups, returned as (input, target).

    Every epoch ``size`` random bits are drawn with
    ``misc.generate_random_bit_array``, reshaped into rows of ``shape`` bits
    and converted with ``misc.bit_matrix2one_hot``. The same array is
    returned as both input and target.

    Args:
        size: Number of random bits per epoch; rounded up to a multiple of
            ``shape`` so the reshape into rows is exact.
        shape: Number of bits per row.
    """

    def __init__(self, size=1e5, shape=2):
        super().__init__()
        self.size = _round_up(size, shape)
        self.shape = shape
        self.x = None
        self.on_epoch_end()

    def on_epoch_end(self):
        """Draw a fresh random bit array and re-encode it into ``self.x``."""
        x_train = misc.generate_random_bit_array(self.size).reshape((-1, self.shape))
        self.x = misc.bit_matrix2one_hot(x_train)

    def __len__(self):
        # NOTE(review): this is the bit count, not a batch count, while every
        # index returns the whole dataset. Keras by default iterates indices
        # 0..len-1 per epoch, so an epoch may replay the full data `size`
        # times. Kept unchanged for compatibility; confirm this is intended.
        return self.size

    def __getitem__(self, idx):
        # `idx` is ignored: every "batch" is the full generated dataset.
        return self.x, self.x


class BinaryTimeDistributedOneHotGenerator(Sequence):
    """Keras ``Sequence`` of one-hot sequences whose target is the centre block.

    Every epoch, ``size * blocks`` random integers in ``[0, cardinality)`` are
    one-hot encoded and reshaped to ``(size, blocks, cardinality)``. The target
    ``y`` is the block at index ``middle`` of every sample.

    Args:
        size: Number of samples per epoch.
        cardinality: Number of categories (width of each one-hot vector).
        blocks: Number of one-hot vectors per sample.
    """

    def __init__(self, size=1e5, cardinality=32, blocks=9):
        super().__init__()
        self.size = int(size)
        self.cardinality = cardinality
        self.x = None
        self.encoder = _dense_one_hot_encoder(self.cardinality)
        # Index of the centre block; for even `blocks` this is the lower middle.
        self.middle = (blocks - 1) // 2
        self.blocks = blocks
        self.on_epoch_end()

    def on_epoch_end(self):
        """Draw fresh random categories and re-encode them into ``self.x``."""
        rand_int = np.random.randint(self.cardinality, size=(self.size * self.blocks, 1))
        out = self.encoder.fit_transform(rand_int)
        self.x = np.reshape(out, (self.size, self.blocks, self.cardinality))

    def __len__(self):
        # NOTE(review): sample count returned as the batch count while every
        # index returns the whole dataset; kept for compatibility.
        return self.size

    @property
    def y(self):
        """Centre block of every sample, shape ``(size, cardinality)``."""
        return self.x[:, self.middle, :]

    def __getitem__(self, idx):
        # `idx` is ignored: every "batch" is the full generated dataset.
        return self.x, self.y


class BinaryGenerator(Sequence):
    """Keras ``Sequence`` of random bit rows as a TensorFlow tensor, returned as (input, target).

    Every epoch ``size`` random bits are drawn with
    ``misc.generate_random_bit_array``, reshaped into rows of ``shape`` bits
    and converted to a tensor of ``dtype``. The same tensor is returned as
    both input and target.

    Args:
        size: Number of random bits per epoch; rounded up to a multiple of
            ``shape`` so the reshape into rows is exact.
        shape: Number of bits per row.
        dtype: TensorFlow dtype of the produced tensor.
    """

    def __init__(self, size=1e5, shape=2, dtype=tf.bool):
        super().__init__()
        self.size = _round_up(size, shape)
        self.shape = shape
        self.x = None
        self.dtype = dtype
        self.on_epoch_end()

    def on_epoch_end(self):
        """Draw a fresh random bit array and store it as a tensor in ``self.x``."""
        x = misc.generate_random_bit_array(self.size).reshape((-1, self.shape))
        self.x = tf.convert_to_tensor(x, dtype=self.dtype)

    def __len__(self):
        # NOTE(review): bit count returned as the batch count while every
        # index returns the whole dataset; kept for compatibility.
        return self.size

    def __getitem__(self, idx):
        # `idx` is ignored: every "batch" is the full generated dataset.
        return self.x, self.x