| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354 |
- import tensorflow as tf
- from tensorflow.keras.utils import Sequence
- # Uncomment to split the first physical CPU into two logical (virtual) CPU devices:
- # physical_devices = tf.config.experimental.list_physical_devices("CPU")
- # tf.config.experimental.set_virtual_device_configuration(
- # physical_devices[0], [
- # tf.config.experimental.VirtualDeviceConfiguration(),
- # tf.config.experimental.VirtualDeviceConfiguration()
- # ])
- import misc
class BinaryOneHotGenerator(Sequence):
    """Keras ``Sequence`` serving random one-hot data as ``(input, target)`` pairs.

    Each epoch draws ``size`` random bits with ``misc.generate_random_bit_array``,
    reshapes them into rows of ``shape`` bits and converts that bit matrix with
    ``misc.bit_matrix2one_hot``. The same array is returned as both input and
    target (autoencoder-style training). Fresh data is drawn on construction and
    again at the end of every epoch.

    The whole dataset is served as one batch, so the sequence has length 1.
    """

    def __init__(self, size=1e5, shape=2, **kwargs):
        """Create the generator and draw the first epoch of data.

        Args:
            size: Number of random bits per epoch. It is converted to ``int`` and
                rounded up to the next multiple of ``shape``.
            shape: Number of bits per row. Must be positive.
            **kwargs: Forwarded to the Keras ``Sequence``/``PyDataset`` base class
                (e.g. ``workers`` on Keras 3). Only pass options the installed
                Keras version accepts.

        Raises:
            ValueError: If ``shape`` is not positive.
        """
        # Keras 3's PyDataset (aliased as Sequence) expects its initialiser to
        # run; with no kwargs this is a no-op on older tf.keras.
        super().__init__(**kwargs)
        if shape < 1:
            raise ValueError(f"shape must be a positive integer, got {shape!r}")
        size = int(size)
        remainder = size % shape
        if remainder:
            # Round up so the bit array reshapes cleanly into rows of `shape`.
            size += shape - remainder
        self.size = size
        self.shape = shape
        self.x = None
        self.on_epoch_end()

    def on_epoch_end(self):
        """Draw a fresh random bit matrix and store its one-hot encoding in ``self.x``."""
        x_train = misc.generate_random_bit_array(self.size).reshape((-1, self.shape))
        self.x = misc.bit_matrix2one_hot(x_train)

    def __len__(self):
        """Return the number of batches per epoch.

        This is always 1, because ``__getitem__`` serves the whole dataset.
        """
        return 1

    def __getitem__(self, idx):
        """Return the single batch ``(x, x)``.

        Raises:
            IndexError: If ``idx`` is out of range for a one-batch sequence.
        """
        if not -1 <= idx < 1:
            raise IndexError(f"batch index {idx} out of range for 1 batch")
        return self.x, self.x
class BinaryGenerator(Sequence):
    """Keras ``Sequence`` serving random bit matrices as ``(input, target)`` pairs.

    Each epoch draws ``size`` random bits with ``misc.generate_random_bit_array``,
    reshapes them into rows of ``shape`` bits and converts the result to a
    TensorFlow tensor of ``dtype``. The same tensor is returned as both input and
    target (autoencoder-style training). Fresh data is drawn on construction and
    again at the end of every epoch.

    The whole dataset is served as one batch, so the sequence has length 1.
    """

    def __init__(self, size=1e5, shape=2, dtype=tf.bool, **kwargs):
        """Create the generator and draw the first epoch of data.

        Args:
            size: Number of random bits per epoch. It is converted to ``int`` and
                rounded up to the next multiple of ``shape``.
            shape: Number of bits per row. Must be positive.
            dtype: TensorFlow dtype of the produced tensor.
            **kwargs: Forwarded to the Keras ``Sequence``/``PyDataset`` base class
                (e.g. ``workers`` on Keras 3). Only pass options the installed
                Keras version accepts.

        Raises:
            ValueError: If ``shape`` is not positive.
        """
        # Keras 3's PyDataset (aliased as Sequence) expects its initialiser to
        # run; with no kwargs this is a no-op on older tf.keras.
        super().__init__(**kwargs)
        if shape < 1:
            raise ValueError(f"shape must be a positive integer, got {shape!r}")
        size = int(size)
        remainder = size % shape
        if remainder:
            # Round up so the bit array reshapes cleanly into rows of `shape`.
            size += shape - remainder
        self.size = size
        self.shape = shape
        self.x = None
        self.dtype = dtype
        self.on_epoch_end()

    def on_epoch_end(self):
        """Draw a fresh random bit matrix and store it in ``self.x`` as a ``self.dtype`` tensor."""
        x = misc.generate_random_bit_array(self.size).reshape((-1, self.shape))
        self.x = tf.convert_to_tensor(x, dtype=self.dtype)

    def __len__(self):
        """Return the number of batches per epoch.

        This is always 1, because ``__getitem__`` serves the whole dataset.
        """
        return 1

    def __getitem__(self, idx):
        """Return the single batch ``(x, x)``.

        Raises:
            IndexError: If ``idx`` is out of range for a one-batch sequence.
        """
        if not -1 <= idx < 1:
            raise IndexError(f"batch index {idx} out of range for 1 batch")
        return self.x, self.x
|