autoencoder.py

import os
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from tensorflow.keras import layers, losses
from tensorflow.keras.layers import LeakyReLU, ReLU
from tensorflow.keras.models import Model

import misc
import defs
from models import basic

latent_dim = 64

print("# GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))


class AutoencoderMod(defs.Modulator):
    """Modulator wrapper: maps N-bit words onto the constellation learned by the encoder."""

    def __init__(self, autoencoder):
        super().__init__(2 ** autoencoder.N)
        self.autoencoder = autoencoder

    def forward(self, binary: np.ndarray) -> defs.Signal:
        reshaped = binary.reshape((-1, self.N))
        reshaped_ho = misc.bit_matrix2one_hot(reshaped)
        encoded = self.autoencoder.encoder(reshaped_ho)
        x = encoded.numpy()
        # Scale and shift into the channel signal range (the same x * 2 - 1 step used in Autoencoder.call)
        x2 = x * 2 - 1
        f = np.zeros(x2.shape[0])
        x3 = misc.rect2polar(np.c_[x2[:, 0], x2[:, 1], f])
        return defs.Signal(x3)


class AutoencoderDemod(defs.Demodulator):
    """Demodulator wrapper: recovers N-bit words from received symbols via the decoder."""

    def __init__(self, autoencoder):
        super().__init__(2 ** autoencoder.N)
        self.autoencoder = autoencoder

    def forward(self, values: defs.Signal) -> np.ndarray:
        decoded = self.autoencoder.decoder(values.rect).numpy()
        # Pick the most likely of the 2**N symbols and convert it back to an N-bit word
        result = misc.int2bit_array(decoded.argmax(axis=1), self.N)
        return result.reshape(-1, )
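
# Bit-to-symbol round trip in brief: AutoencoderMod reshapes the incoming bit stream into
# N-bit words, one-hot encodes them (2**N columns), runs them through the trained encoder to
# get one (I, Q) point per word, rescales it with x * 2 - 1, and returns the points in polar
# form as a defs.Signal. AutoencoderDemod inverts this by feeding the received rectangular
# coordinates to the decoder and taking the argmax over its 2**N softmax outputs, which
# misc.int2bit_array converts back into an N-bit word per symbol.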


class Autoencoder(Model):
    def __init__(self, N, noise):
        super(Autoencoder, self).__init__()
        self.N = N

        # Encoder: one-hot symbol (2**N wide) -> a single (I, Q) constellation point
        self.encoder = tf.keras.Sequential()
        self.encoder.add(tf.keras.Input(shape=(2 ** N,), dtype=bool))
        self.encoder.add(layers.Dense(units=2 ** (N + 1)))
        self.encoder.add(LeakyReLU(alpha=0.001))
        # self.encoder.add(layers.Dropout(0.2))
        self.encoder.add(layers.Dense(units=2 ** (N + 1)))
        self.encoder.add(LeakyReLU(alpha=0.001))
        self.encoder.add(layers.Dense(units=2, activation="tanh"))
        # self.encoder.add(layers.ReLU(max_value=1.0))

        # Decoder: received (I, Q) point -> softmax over the 2**N possible symbols
        self.decoder = tf.keras.Sequential()
        self.decoder.add(tf.keras.Input(shape=(2,)))
        self.decoder.add(layers.Dense(units=2 ** (N + 1)))
        # Leaky ReLU with alpha=1 gives by far the best results
        self.decoder.add(LeakyReLU(alpha=1))
        self.decoder.add(layers.Dense(units=2 ** N, activation="softmax"))
        # self.decoder.add(layers.Softmax(units=4, dtype=bool))

        # self.randomiser = tf.random_normal_initializer(mean=0.0, stddev=0.1, seed=None)
        self.mod = None
        self.demod = None
        self.compiled = False
        # Divide by 2 because the encoder outputs values between 0 and 1 instead of -1 and 1
        self.noise = noise  # 10 ** (noise / 10)  # / 2

        # Leftover convolutional (image) variant, unused:
        # self.encoder = tf.keras.Sequential([
        #     layers.Input(shape=(28, 28, 1)),
        #     layers.Conv2D(16, (3, 3), activation='relu', padding='same', strides=2),
        #     layers.Conv2D(8, (3, 3), activation='relu', padding='same', strides=2)
        # ])
        # self.decoder = tf.keras.Sequential([
        #     layers.Conv2DTranspose(8, kernel_size=3, strides=2, activation='relu', padding='same'),
        #     layers.Conv2DTranspose(16, kernel_size=3, strides=2, activation='relu', padding='same'),
        #     layers.Conv2D(1, kernel_size=(3, 3), activation='sigmoid', padding='same')
        # ])

    def call(self, x, **kwargs):
        # End-to-end link: one-hot word -> encoder -> AWGN channel -> decoder softmax
        chan = basic.AWGNChannel(self.noise)
        signal = self.encoder(x)
        signal = signal * 2 - 1
        signal = chan.forward_tensor(signal)
        # encoded = encoded * 2 - 1
        # encoded = tf.clip_by_value(encoded, clip_value_min=0, clip_value_max=1, name=None)
        # noise = self.randomiser(shape=(-1, 2), dtype=tf.float32)
        # noise = np.random.normal(0, 1, (1, 2)) * self.noise
        # noisy = tf.convert_to_tensor(noise, dtype=tf.float32)
        decoded = self.decoder(signal)
        return decoded
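
    # A stand-in for the channel step, for orientation only (an assumption about what
    # basic.AWGNChannel.forward_tensor does, treating self.noise as a noise level in dB;
    # the project's own implementation may differ):
    #
    #   def _awgn_sketch(signal, noise_db):
    #       sigma = tf.sqrt(10.0 ** (noise_db / 10.0) / 2.0)   # per-component standard deviation
    #       return signal + tf.random.normal(tf.shape(signal), stddev=sigma)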

    def fit_encoder(self, modulation, sample_size, train_size=0.8, epochs=1, batch_size=1, shuffle=False):
        alphabet = basic.load_alphabet(modulation, polar=False)
        # One symbol per N-bit word, so the reference alphabet must contain 2**N points
        if alphabet.shape[0] != 2 ** self.N:
            raise Exception("Cardinality of modulation scheme is different from cardinality of autoencoder!")
        x_train = np.random.randint(2 ** self.N, size=int(sample_size * train_size))
        y_train = alphabet[x_train]
        x_train_ho = np.zeros((int(sample_size * train_size), 2 ** self.N))
        for idx, x in np.ndenumerate(x_train):
            x_train_ho[idx, x] = 1
        x_test = np.random.randint(2 ** self.N, size=int(sample_size * (1 - train_size)))
        y_test = alphabet[x_test]
        x_test_ho = np.zeros((int(sample_size * (1 - train_size)), 2 ** self.N))
        for idx, x in np.ndenumerate(x_test):
            x_test_ho[idx, x] = 1
        # Train the encoder alone to place each one-hot word on the reference alphabet point
        self.encoder.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())
        self.encoder.fit(x_train_ho, y_train,
                         epochs=epochs,
                         batch_size=batch_size,
                         shuffle=shuffle,
                         validation_data=(x_test_ho, y_test))
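
    # Equivalent one-hot construction without the explicit loops (same result):
    #   x_train_ho = np.eye(2 ** self.N)[x_train]
    #   x_test_ho = np.eye(2 ** self.N)[x_test]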

    def fit_decoder(self, modulation, samples):
        # Draw 30% extra points; train_test_split below holds part of them out for validation
        samples = int(samples * 1.3)
        demod = basic.AlphabetDemod(modulation, 0)
        # Uniform random received points over the [-1, 1] x [-1, 1] plane, labelled by the
        # reference demodulator
        x = np.random.rand(samples, 2) * 2 - 1
        x = x.reshape((-1, 2))
        f = np.zeros(x.shape[0])
        xf = np.c_[x[:, 0], x[:, 1], f]
        y = demod.forward(misc.rect2polar(xf))
        y_ho = misc.bit_matrix2one_hot(y.reshape((-1, self.N)))
        X_train, X_test, y_train, y_test = train_test_split(x, y_ho)
        self.decoder.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())
        self.decoder.fit(X_train, y_train, shuffle=False, validation_data=(X_test, y_test))
        y_pred = self.decoder(X_test).numpy()
        y_pred2 = np.zeros(y_test.shape, dtype=bool)
        y_pred2[np.arange(y_pred2.shape[0]), np.argmax(y_pred, axis=1)] = True
        print("Accuracy: %.4f" % accuracy_score(y_pred2, y_test))
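
    # Optional sanity check (a sketch, assuming a trained decoder): inspect the learned
    # decision regions by classifying a dense grid of received points, e.g.
    #   xs = np.linspace(-1, 1, 200)
    #   grid = np.array(np.meshgrid(xs, xs)).reshape(2, -1).T
    #   regions = self.decoder(grid).numpy().argmax(axis=1)   # symbol index per grid point
    #   plt.scatter(grid[:, 0], grid[:, 1], c=regions, s=1); plt.show()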

    def train(self, samples=1e6):
        samples = int(samples)
        # Pad the sample counts so that they divide evenly into N-bit words
        if samples % self.N:
            samples += self.N - (samples % self.N)
        x_train = misc.generate_random_bit_array(samples).reshape((-1, self.N))
        x_train_ho = misc.bit_matrix2one_hot(x_train)
        test_samples = int(samples * 0.3)
        if test_samples % self.N:
            test_samples += self.N - (test_samples % self.N)
        x_test_array = misc.generate_random_bit_array(test_samples)
        x_test = x_test_array.reshape((-1, self.N))
        x_test_ho = misc.bit_matrix2one_hot(x_test)
        if not self.compiled:
            self.compile(optimizer='adam', loss=losses.MeanSquaredError())
            self.compiled = True
        # End-to-end training: the model must reproduce its own one-hot input after the noisy channel
        self.fit(x_train_ho, x_train_ho, shuffle=False, validation_data=(x_test_ho, x_test_ho))
        # encoded_data = self.encoder(x_test_ho)
        # decoded_data = self.decoder(encoded_data).numpy()

    def get_modulator(self):
        if self.mod is None:
            self.mod = AutoencoderMod(self)
        return self.mod

    def get_demodulator(self):
        if self.demod is None:
            self.demod = AutoencoderDemod(self)
        return self.demod


def view_encoder(encoder, N, samples=1000):
    test_values = misc.generate_random_bit_array(samples).reshape((-1, N))
    test_values_ho = misc.bit_matrix2one_hot(test_values)
    # Interpret each N-bit row as an integer symbol index
    mvector = np.array([2 ** i for i in range(N)], dtype=int)
    symbols = (test_values * mvector).sum(axis=1)
    encoded = encoder(test_values_ho).numpy()
    # encoded = misc.polar2rect(encoded)
    for i in range(2 ** N):
        xy = encoded[symbols == i]
        plt.plot(xy[:, 0], xy[:, 1], 'x', markersize=12, label=format(i, f'0{N}b'))
        plt.annotate(text=format(i, f'0{N}b'), xy=[xy[:, 0].mean() + 0.01, xy[:, 1].mean() + 0.01])
    plt.xlabel('Real')
    plt.ylabel('Imaginary')
    plt.title("Autoencoder generated alphabet")
    # plt.legend()
    plt.show()
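
# Reading the plot: each cluster of markers is the set of (I, Q) points the encoder produces
# for one N-bit word, annotated with that word's bit pattern. After fit_encoder the layout
# should approximate the reference 16-QAM alphabet; after the end-to-end train() pass the
# points may shift as the joint loss reshapes the constellation for the noisy channel.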


if __name__ == '__main__':
    # (x_train, _), (x_test, _) = fashion_mnist.load_data()
    #
    # x_train = x_train.astype('float32') / 255.
    # x_test = x_test.astype('float32') / 255.
    #
    # print(f"Train data: {x_train.shape}")
    # print(f"Test data: {x_test.shape}")
    n = 4
    # samples = 1e6
    # x_train = misc.generate_random_bit_array(samples).reshape((-1, n))
    # x_train_ho = misc.bit_matrix2one_hot(x_train)
    # x_test_array = misc.generate_random_bit_array(samples * 0.3)
    # x_test = x_test_array.reshape((-1, n))
    # x_test_ho = misc.bit_matrix2one_hot(x_test)

    # Build a 16-symbol (n = 4 bit) autoencoder; -8 is the noise parameter handed to
    # basic.AWGNChannel inside call()
    autoencoder = Autoencoder(n, -8)
    autoencoder.compile(optimizer='adam', loss=losses.MeanSquaredError())

    # Pre-train the encoder to reproduce the reference 16-QAM constellation
    autoencoder.fit_encoder(modulation='16qam',
                            sample_size=2e6,
                            train_size=0.8,
                            epochs=1,
                            batch_size=256,
                            shuffle=True)
    view_encoder(autoencoder.encoder, n)

    # Pre-train the decoder against the reference 16-QAM demodulator
    autoencoder.fit_decoder(modulation='16qam', samples=2e6)

    # Fine-tune encoder and decoder jointly, end to end, over the AWGN channel
    autoencoder.train()
    view_encoder(autoencoder.encoder, n)
    # view_encoder(autoencoder.encoder, n)
    # view_encoder(autoencoder.encoder, n)

    # autoencoder.compile(optimizer='adam', loss=losses.MeanSquaredError())
    #
    # autoencoder.fit(x_train_ho, x_train_ho,
    #                 epochs=1,
    #                 shuffle=False,
    #                 validation_data=(x_test_ho, x_test_ho))
    #
    # encoded_data = autoencoder.encoder(x_test_ho)
    # decoded_data = autoencoder.decoder(encoded_data).numpy()
    #
    # result = misc.int2bit_array(decoded_data.argmax(axis=1), n)
    # print("Accuracy: %.4f" % accuracy_score(x_test_array, result.reshape(-1, )))
    # view_encoder(autoencoder.encoder, n)
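
    # Possible follow-up (a sketch, not run here): estimate the end-to-end bit error rate
    # of the trained system through its Modulator/Demodulator wrappers, e.g.
    #   mod, demod = autoencoder.get_modulator(), autoencoder.get_demodulator()
    #   payload = misc.generate_random_bit_array(n * 100000)
    #   print("BER: %.5f" % np.mean(payload != demod.forward(mod.forward(payload))))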