data.py 969 B

1234567891011121314151617181920212223242526272829303132
  1. import tensorflow as tf
  2. from tensorflow.keras.utils import Sequence
  3. # Disabled: splits the first physical CPU into two virtual CPU devices.
  4. # physical_devices = tf.config.experimental.list_physical_devices("CPU")
  5. # tf.config.experimental.set_virtual_device_configuration(
  6. # physical_devices[0], [
  7. # tf.config.experimental.VirtualDeviceConfiguration(),
  8. # tf.config.experimental.VirtualDeviceConfiguration()
  9. # ])
  10. import misc
  11. class BinaryOneHotGenerator(Sequence):
  12. def __init__(self, size=1e5, shape=2):
  13. size = int(size)
  14. if size % shape:
  15. size += shape - (size % shape)
  16. self.size = size
  17. self.shape = shape
  18. self.x = None
  19. self.on_epoch_end()
  20. def on_epoch_end(self):
  21. x_train = misc.generate_random_bit_array(self.size).reshape((-1, self.shape))
  22. self.x = misc.bit_matrix2one_hot(x_train)
  23. def __len__(self):
  24. return self.size
  25. def __getitem__(self, idx):
  26. return self.x, self.x