data.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. import tensorflow as tf
  2. from tensorflow.keras.utils import Sequence
  3. # (Disabled) Split the first physical CPU into two logical (virtual) devices
  4. # physical_devices = tf.config.experimental.list_physical_devices("CPU")
  5. # tf.config.experimental.set_virtual_device_configuration(
  6. # physical_devices[0], [
  7. # tf.config.experimental.VirtualDeviceConfiguration(),
  8. # tf.config.experimental.VirtualDeviceConfiguration()
  9. # ])
  10. import misc
  11. class BinaryOneHotGenerator(Sequence):
  12. def __init__(self, size=1e5, shape=2):
  13. size = int(size)
  14. if size % shape:
  15. size += shape - (size % shape)
  16. self.size = size
  17. self.shape = shape
  18. self.x = None
  19. self.on_epoch_end()
  20. def on_epoch_end(self):
  21. x_train = misc.generate_random_bit_array(self.size).reshape((-1, self.shape))
  22. self.x = misc.bit_matrix2one_hot(x_train)
  23. def __len__(self):
  24. return self.size
  25. def __getitem__(self, idx):
  26. return self.x, self.x
  27. class BinaryGenerator(Sequence):
  28. def __init__(self, size=1e5, shape=2, dtype=tf.bool):
  29. size = int(size)
  30. if size % shape:
  31. size += shape - (size % shape)
  32. self.size = size
  33. self.shape = shape
  34. self.x = None
  35. self.dtype = dtype
  36. self.on_epoch_end()
  37. def on_epoch_end(self):
  38. x = misc.generate_random_bit_array(self.size).reshape((-1, self.shape))
  39. self.x = tf.convert_to_tensor(x, dtype=self.dtype)
  40. def __len__(self):
  41. return self.size
  42. def __getitem__(self, idx):
  43. return self.x, self.x