autoencoder.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. import tensorflow as tf
  4. from sklearn.metrics import accuracy_score
  5. from tensorflow.keras import layers, losses
  6. from tensorflow.keras.models import Model
  7. from functools import partial
  8. import misc
  9. import defs
  10. from models import basic
  11. latent_dim = 64
  12. print("# GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))
  13. class AutoencoderMod(defs.Modulator):
  14. def __init__(self, autoencoder):
  15. super().__init__(2**autoencoder.N)
  16. self.autoencoder = autoencoder
  17. def forward(self, binary: np.ndarray) -> defs.Signal:
  18. reshaped = binary.reshape((-1, self.N))
  19. reshaped_ho = misc.bit_matrix2one_hot(reshaped)
  20. encoded = self.autoencoder.encoder(reshaped_ho)
  21. x = encoded.numpy()
  22. x2 = x * 2 - 1
  23. f = np.zeros(x2.shape[0])
  24. x3 = misc.rect2polar(np.c_[x2[:, 0], x2[:, 1], f])
  25. return defs.Signal(x3)
  26. class AutoencoderDemod(defs.Demodulator):
  27. def __init__(self, autoencoder):
  28. super().__init__(2**autoencoder.N)
  29. self.autoencoder = autoencoder
  30. def forward(self, values: defs.Signal) -> np.ndarray:
  31. decoded = self.autoencoder.decoder(values.rect).numpy()
  32. result = misc.int2bit_array(decoded.argmax(axis=1), self.N)
  33. return result.reshape(-1, )
  34. class Autoencoder(Model):
  35. def __init__(self, N, noise):
  36. super(Autoencoder, self).__init__()
  37. self.N = N
  38. self.encoder = tf.keras.Sequential()
  39. self.encoder.add(tf.keras.Input(shape=(2 ** N,), dtype=bool))
  40. self.encoder.add(layers.Dense(units=2 ** (N + 1)))
  41. # self.encoder.add(layers.Dropout(0.2))
  42. self.encoder.add(layers.Dense(units=2 ** (N + 1)))
  43. self.encoder.add(layers.Dense(units=2, activation="sigmoid"))
  44. # self.encoder.add(layers.ReLU(max_value=1.0))
  45. self.decoder = tf.keras.Sequential()
  46. self.decoder.add(tf.keras.Input(shape=(2,)))
  47. self.decoder.add(layers.Dense(units=2 ** (N + 1)))
  48. # self.decoder.add(layers.Dropout(0.2))
  49. self.decoder.add(layers.Dense(units=2 ** (N + 1)))
  50. self.decoder.add(layers.Dense(units=2 ** N, activation="softmax"))
  51. # self.randomiser = tf.random_normal_initializer(mean=0.0, stddev=0.1, seed=None)
  52. self.mod = None
  53. self.demod = None
  54. self.compiled = False
  55. # Divide by 2 because encoder outputs values between 0 and 1 instead of -1 and 1
  56. self.noise = noise #10 ** (noise / 10) # / 2
  57. # self.decoder.add(layers.Softmax(units=4, dtype=bool))
  58. # [
  59. # layers.Input(shape=(28, 28, 1)),
  60. # layers.Conv2D(16, (3, 3), activation='relu', padding='same', strides=2),
  61. # layers.Conv2D(8, (3, 3), activation='relu', padding='same', strides=2)
  62. # ])
  63. # self.decoder = tf.keras.Sequential([
  64. # layers.Conv2DTranspose(8, kernel_size=3, strides=2, activation='relu', padding='same'),
  65. # layers.Conv2DTranspose(16, kernel_size=3, strides=2, activation='relu', padding='same'),
  66. # layers.Conv2D(1, kernel_size=(3, 3), activation='sigmoid', padding='same')
  67. # ])
  68. def call(self, x, **kwargs):
  69. chan = basic.AWGNChannel(self.noise)
  70. signal = self.encoder(x)
  71. signal = signal * 2 - 1
  72. signal = chan.forward_tensor(signal)
  73. # encoded = encoded * 2 - 1
  74. # encoded = tf.clip_by_value(encoded, clip_value_min=0, clip_value_max=1, name=None)
  75. # noise = self.randomiser(shape=(-1, 2), dtype=tf.float32)
  76. # noise = np.random.normal(0, 1, (1, 2)) * self.noise
  77. # noisy = tf.convert_to_tensor(noise, dtype=tf.float32)
  78. decoded = self.decoder(signal)
  79. return decoded
  80. def train(self, samples=1e6):
  81. if samples % self.N:
  82. samples += self.N - (samples % self.N)
  83. x_train = misc.generate_random_bit_array(samples).reshape((-1, self.N))
  84. x_train_ho = misc.bit_matrix2one_hot(x_train)
  85. test_samples = samples * 0.3
  86. if test_samples % self.N:
  87. test_samples += self.N - (test_samples % self.N)
  88. x_test_array = misc.generate_random_bit_array(test_samples)
  89. x_test = x_test_array.reshape((-1, self.N))
  90. x_test_ho = misc.bit_matrix2one_hot(x_test)
  91. if not self.compiled:
  92. self.compile(optimizer='adam', loss=losses.MeanSquaredError())
  93. self.compiled = True
  94. self.fit(x_train_ho, x_train_ho, shuffle=False, validation_data=(x_test_ho, x_test_ho))
  95. # encoded_data = self.encoder(x_test_ho)
  96. # decoded_data = self.decoder(encoded_data).numpy()
  97. def get_modulator(self):
  98. if self.mod is None:
  99. self.mod = AutoencoderMod(self)
  100. return self.mod
  101. def get_demodulator(self):
  102. if self.demod is None:
  103. self.demod = AutoencoderDemod(self)
  104. return self.demod
  105. def view_encoder(encoder, N, samples=1000):
  106. test_values = misc.generate_random_bit_array(samples).reshape((-1, N))
  107. test_values_ho = misc.bit_matrix2one_hot(test_values)
  108. mvector = np.array([2 ** i for i in range(N)], dtype=int)
  109. symbols = (test_values * mvector).sum(axis=1)
  110. encoded = encoder(test_values_ho).numpy()
  111. # encoded = misc.polar2rect(encoded)
  112. for i in range(2 ** N):
  113. xy = encoded[symbols == i]
  114. plt.plot(xy[:, 0], xy[:, 1], 'x', markersize=12, label=format(i, f'0{N}b'))
  115. plt.annotate(xy=[xy[:, 0].mean() + 0.01, xy[:, 1].mean() + 0.01], text=format(i, f'0{N}b'))
  116. plt.xlabel('Real')
  117. plt.ylabel('Imaginary')
  118. plt.title("Autoencoder generated alphabet")
  119. # plt.legend()
  120. plt.show()
  121. pass
  122. if __name__ == '__main__':
  123. # (x_train, _), (x_test, _) = fashion_mnist.load_data()
  124. #
  125. # x_train = x_train.astype('float32') / 255.
  126. # x_test = x_test.astype('float32') / 255.
  127. #
  128. # print(f"Train data: {x_train.shape}")
  129. # print(f"Test data: {x_test.shape}")
  130. n = 4
  131. samples = 1e6
  132. x_train = misc.generate_random_bit_array(samples).reshape((-1, n))
  133. x_train_ho = misc.bit_matrix2one_hot(x_train)
  134. x_test_array = misc.generate_random_bit_array(samples * 0.3)
  135. x_test = x_test_array.reshape((-1, n))
  136. x_test_ho = misc.bit_matrix2one_hot(x_test)
  137. autoencoder = Autoencoder(n, -8)
  138. autoencoder.compile(optimizer='adam', loss=losses.MeanSquaredError())
  139. autoencoder.fit(x_train_ho, x_train_ho,
  140. epochs=1,
  141. shuffle=False,
  142. validation_data=(x_test_ho, x_test_ho))
  143. encoded_data = autoencoder.encoder(x_test_ho)
  144. decoded_data = autoencoder.decoder(encoded_data).numpy()
  145. result = misc.int2bit_array(decoded_data.argmax(axis=1), n)
  146. print("Accuracy: %.4f" % accuracy_score(x_test_array, result.reshape(-1, )))
  147. view_encoder(autoencoder.encoder, n)
  148. pass