autoencoder.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. import tensorflow as tf
  4. from sklearn.metrics import accuracy_score
  5. from sklearn.model_selection import train_test_split
  6. from tensorflow.keras import layers, losses
  7. from tensorflow.keras.models import Model
  8. from tensorflow.python.keras.layers import LeakyReLU, ReLU
  9. from functools import partial
  10. import misc
  11. import defs
  12. from models import basic
  13. import os
  14. latent_dim = 64
  15. print("# GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))
  16. class AutoencoderMod(defs.Modulator):
  17. def __init__(self, autoencoder):
  18. super().__init__(2 ** autoencoder.N)
  19. self.autoencoder = autoencoder
  20. def forward(self, binary: np.ndarray) -> defs.Signal:
  21. reshaped = binary.reshape((-1, self.N))
  22. reshaped_ho = misc.bit_matrix2one_hot(reshaped)
  23. encoded = self.autoencoder.encoder(reshaped_ho)
  24. x = encoded.numpy()
  25. x2 = x * 2 - 1
  26. f = np.zeros(x2.shape[0])
  27. x3 = misc.rect2polar(np.c_[x2[:, 0], x2[:, 1], f])
  28. return defs.Signal(x3)
  29. class AutoencoderDemod(defs.Demodulator):
  30. def __init__(self, autoencoder):
  31. super().__init__(2 ** autoencoder.N)
  32. self.autoencoder = autoencoder
  33. def forward(self, values: defs.Signal) -> np.ndarray:
  34. decoded = self.autoencoder.decoder(values.rect).numpy()
  35. result = misc.int2bit_array(decoded.argmax(axis=1), self.N)
  36. return result.reshape(-1, )
  37. class Autoencoder(Model):
  38. def __init__(self, N, channel, signal_dim=2):
  39. super(Autoencoder, self).__init__()
  40. self.N = N
  41. self.encoder = tf.keras.Sequential()
  42. self.encoder.add(tf.keras.Input(shape=(2 ** N,), dtype=bool))
  43. self.encoder.add(layers.Dense(units=2 ** (N + 1)))
  44. self.encoder.add(LeakyReLU(alpha=0.001))
  45. # self.encoder.add(layers.Dropout(0.2))
  46. self.encoder.add(layers.Dense(units=2 ** (N + 1)))
  47. self.encoder.add(LeakyReLU(alpha=0.001))
  48. self.encoder.add(layers.Dense(units=signal_dim, activation="tanh"))
  49. # self.encoder.add(layers.ReLU(max_value=1.0))
  50. self.decoder = tf.keras.Sequential()
  51. self.decoder.add(tf.keras.Input(shape=(signal_dim,)))
  52. self.decoder.add(layers.Dense(units=2 ** (N + 1)))
  53. # leaky relu with alpha=1 gives by far best results
  54. self.decoder.add(LeakyReLU(alpha=1))
  55. self.decoder.add(layers.Dense(units=2 ** N, activation="softmax"))
  56. # self.randomiser = tf.random_normal_initializer(mean=0.0, stddev=0.1, seed=None)
  57. self.mod = None
  58. self.demod = None
  59. self.compiled = False
  60. if isinstance(channel, int) or isinstance(channel, float):
  61. self.channel = basic.AWGNChannel(channel)
  62. else:
  63. if not hasattr(channel, 'forward_tensor'):
  64. raise ValueError("Channel has no forward_tensor function")
  65. if not callable(channel.forward_tensor):
  66. raise ValueError("Channel.forward_tensor is not callable")
  67. self.channel = channel
  68. # self.decoder.add(layers.Softmax(units=4, dtype=bool))
  69. # [
  70. # layers.Input(shape=(28, 28, 1)),
  71. # layers.Conv2D(16, (3, 3), activation='relu', padding='same', strides=2),
  72. # layers.Conv2D(8, (3, 3), activation='relu', padding='same', strides=2)
  73. # ])
  74. # self.decoder = tf.keras.Sequential([
  75. # layers.Conv2DTranspose(8, kernel_size=3, strides=2, activation='relu', padding='same'),
  76. # layers.Conv2DTranspose(16, kernel_size=3, strides=2, activation='relu', padding='same'),
  77. # layers.Conv2D(1, kernel_size=(3, 3), activation='sigmoid', padding='same')
  78. # ])
  79. def call(self, x, **kwargs):
  80. signal = self.encoder(x)
  81. signal = signal * 2 - 1
  82. signal = self.channel.forward_tensor(signal)
  83. # encoded = encoded * 2 - 1
  84. # encoded = tf.clip_by_value(encoded, clip_value_min=0, clip_value_max=1, name=None)
  85. # noise = self.randomiser(shape=(-1, 2), dtype=tf.float32)
  86. # noise = np.random.normal(0, 1, (1, 2)) * self.noise
  87. # noisy = tf.convert_to_tensor(noise, dtype=tf.float32)
  88. decoded = self.decoder(signal)
  89. return decoded
  90. def fit_encoder(self, modulation, sample_size, train_size=0.8, epochs=1, batch_size=1, shuffle=False):
  91. alphabet = basic.load_alphabet(modulation, polar=False)
  92. if not alphabet.shape[0] == self.N ** 2:
  93. raise Exception("Cardinality of modulation scheme is different from cardinality of autoencoder!")
  94. x_train = np.random.randint(self.N ** 2, size=int(sample_size * train_size))
  95. y_train = alphabet[x_train]
  96. x_train_ho = np.zeros((int(sample_size * train_size), self.N ** 2))
  97. for idx, x in np.ndenumerate(x_train):
  98. x_train_ho[idx, x] = 1
  99. x_test = np.random.randint(self.N ** 2, size=int(sample_size * (1 - train_size)))
  100. y_test = alphabet[x_test]
  101. x_test_ho = np.zeros((int(sample_size * (1 - train_size)), self.N ** 2))
  102. for idx, x in np.ndenumerate(x_test):
  103. x_test_ho[idx, x] = 1
  104. self.encoder.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())
  105. self.encoder.fit(x_train_ho, y_train,
  106. epochs=epochs,
  107. batch_size=batch_size,
  108. shuffle=shuffle,
  109. validation_data=(x_test_ho, y_test))
  110. def fit_decoder(self, modulation, samples):
  111. samples = int(samples * 1.3)
  112. demod = basic.AlphabetDemod(modulation, 0)
  113. x = np.random.rand(samples, 2) * 2 - 1
  114. x = x.reshape((-1, 2))
  115. f = np.zeros(x.shape[0])
  116. xf = np.c_[x[:, 0], x[:, 1], f]
  117. y = demod.forward(defs.Signal(misc.rect2polar(xf)))
  118. y_ho = misc.bit_matrix2one_hot(y.reshape((-1, 4)))
  119. X_train, X_test, y_train, y_test = train_test_split(x, y_ho)
  120. self.decoder.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())
  121. self.decoder.fit(X_train, y_train, shuffle=False, validation_data=(X_test, y_test))
  122. y_pred = self.decoder(X_test).numpy()
  123. y_pred2 = np.zeros(y_test.shape, dtype=bool)
  124. y_pred2[np.arange(y_pred2.shape[0]), np.argmax(y_pred, axis=1)] = True
  125. print("Decoder accuracy: %.4f" % accuracy_score(y_pred2, y_test))
  126. def train(self, samples=1e6):
  127. if samples % self.N:
  128. samples += self.N - (samples % self.N)
  129. x_train = misc.generate_random_bit_array(samples).reshape((-1, self.N))
  130. x_train_ho = misc.bit_matrix2one_hot(x_train)
  131. test_samples = samples * 0.3
  132. if test_samples % self.N:
  133. test_samples += self.N - (test_samples % self.N)
  134. x_test_array = misc.generate_random_bit_array(test_samples)
  135. x_test = x_test_array.reshape((-1, self.N))
  136. x_test_ho = misc.bit_matrix2one_hot(x_test)
  137. if not self.compiled:
  138. self.compile(optimizer='adam', loss=losses.MeanSquaredError())
  139. self.compiled = True
  140. self.fit(x_train_ho, x_train_ho, shuffle=False, validation_data=(x_test_ho, x_test_ho))
  141. # encoded_data = self.encoder(x_test_ho)
  142. # decoded_data = self.decoder(encoded_data).numpy()
  143. def get_modulator(self):
  144. if self.mod is None:
  145. self.mod = AutoencoderMod(self)
  146. return self.mod
  147. def get_demodulator(self):
  148. if self.demod is None:
  149. self.demod = AutoencoderDemod(self)
  150. return self.demod
  151. def view_encoder(encoder, N, samples=1000):
  152. test_values = misc.generate_random_bit_array(samples).reshape((-1, N))
  153. test_values_ho = misc.bit_matrix2one_hot(test_values)
  154. mvector = np.array([2 ** i for i in range(N)], dtype=int)
  155. symbols = (test_values * mvector).sum(axis=1)
  156. encoded = encoder(test_values_ho).numpy()
  157. # encoded = misc.polar2rect(encoded)
  158. for i in range(2 ** N):
  159. xy = encoded[symbols == i]
  160. plt.plot(xy[:, 0], xy[:, 1], 'x', markersize=12, label=format(i, f'0{N}b'))
  161. plt.annotate(xy=[xy[:, 0].mean() + 0.01, xy[:, 1].mean() + 0.01], text=format(i, f'0{N}b'))
  162. plt.xlabel('Real')
  163. plt.ylabel('Imaginary')
  164. plt.title("Autoencoder generated alphabet")
  165. # plt.legend()
  166. plt.show()
  167. pass
  168. if __name__ == '__main__':
  169. # (x_train, _), (x_test, _) = fashion_mnist.load_data()
  170. #
  171. # x_train = x_train.astype('float32') / 255.
  172. # x_test = x_test.astype('float32') / 255.
  173. #
  174. # print(f"Train data: {x_train.shape}")
  175. # print(f"Test data: {x_test.shape}")
  176. n = 4
  177. # samples = 1e6
  178. # x_train = misc.generate_random_bit_array(samples).reshape((-1, n))
  179. # x_train_ho = misc.bit_matrix2one_hot(x_train)
  180. # x_test_array = misc.generate_random_bit_array(samples * 0.3)
  181. # x_test = x_test_array.reshape((-1, n))
  182. # x_test_ho = misc.bit_matrix2one_hot(x_test)
  183. autoencoder = Autoencoder(n, -8)
  184. autoencoder.compile(optimizer='adam', loss=losses.MeanSquaredError())
  185. autoencoder.fit_encoder(modulation='16qam',
  186. sample_size=2e6,
  187. train_size=0.8,
  188. epochs=1,
  189. batch_size=256,
  190. shuffle=True)
  191. view_encoder(autoencoder.encoder, n)
  192. autoencoder.fit_decoder(modulation='16qam', samples=2e6)
  193. autoencoder.train()
  194. view_encoder(autoencoder.encoder, n)
  195. # view_encoder(autoencoder.encoder, n)
  196. # view_encoder(autoencoder.encoder, n)
  197. # autoencoder.compile(optimizer='adam', loss=losses.MeanSquaredError())
  198. #
  199. # autoencoder.fit(x_train_ho, x_train_ho,
  200. # epochs=1,
  201. # shuffle=False,
  202. # validation_data=(x_test_ho, x_test_ho))
  203. #
  204. # encoded_data = autoencoder.encoder(x_test_ho)
  205. # decoded_data = autoencoder.decoder(encoded_data).numpy()
  206. #
  207. # result = misc.int2bit_array(decoded_data.argmax(axis=1), n)
  208. # print("Accuracy: %.4f" % accuracy_score(x_test_array, result.reshape(-1, )))
  209. # view_encoder(autoencoder.encoder, n)
  210. pass