"""Autoencoder-based modulator/demodulator models built on TensorFlow/Keras."""
  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. import tensorflow as tf
  4. from sklearn.metrics import accuracy_score
  5. from sklearn.model_selection import train_test_split
  6. from tensorflow.keras import layers, losses
  7. from tensorflow.keras.models import Model
  8. from tensorflow.python.keras.layers import LeakyReLU, ReLU
  9. # from functools import partial
  10. import misc
  11. import defs
  12. from models import basic
  13. import os
  14. # from tensorflow_model_optimization.python.core.quantization.keras import quantize, quantize_aware_activation
  15. from models.data import BinaryOneHotGenerator
  16. from models import layers as custom_layers
  17. latent_dim = 64
  18. print("# GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))
  19. class AutoencoderMod(defs.Modulator):
  20. def __init__(self, autoencoder, encoder=None):
  21. super().__init__(2 ** autoencoder.N)
  22. self.autoencoder = autoencoder
  23. self.encoder = encoder or autoencoder.encoder
  24. def forward(self, binary: np.ndarray):
  25. reshaped = binary.reshape((-1, (self.N * self.autoencoder.parallel)))
  26. reshaped_ho = misc.bit_matrix2one_hot(reshaped)
  27. encoded = self.encoder(reshaped_ho)
  28. x = encoded.numpy()
  29. if self.autoencoder.bipolar:
  30. x = x * 2 - 1
  31. if self.autoencoder.parallel > 1:
  32. x = x.reshape((-1, self.autoencoder.signal_dim))
  33. f = np.zeros(x.shape[0])
  34. if self.autoencoder.signal_dim <= 1:
  35. p = np.zeros(x.shape[0])
  36. else:
  37. p = x[:, 1]
  38. x3 = misc.rect2polar(np.c_[x[:, 0], p, f])
  39. return basic.RFSignal(x3)
  40. class AutoencoderDemod(defs.Demodulator):
  41. def __init__(self, autoencoder, decoder=None):
  42. super().__init__(2 ** autoencoder.N)
  43. self.autoencoder = autoencoder
  44. self.decoder = decoder or autoencoder.decoder
  45. def forward(self, values: defs.Signal) -> np.ndarray:
  46. if self.autoencoder.signal_dim <= 1:
  47. val = values.rect_x
  48. else:
  49. val = values.rect
  50. if self.autoencoder.parallel > 1:
  51. val = val.reshape((-1, self.autoencoder.parallel))
  52. decoded = self.decoder(val).numpy()
  53. result = misc.int2bit_array(decoded.argmax(axis=1), self.N * self.autoencoder.parallel)
  54. return result.reshape(-1, )
  55. class Autoencoder(Model):
  56. def __init__(self, N, channel,
  57. signal_dim=2,
  58. parallel=1,
  59. all_onehot=True,
  60. bipolar=True,
  61. encoder=None,
  62. decoder=None,
  63. data_generator=None,
  64. cost=None
  65. ):
  66. super(Autoencoder, self).__init__()
  67. self.N = N
  68. self.parallel = parallel
  69. self.signal_dim = signal_dim
  70. self.bipolar = bipolar
  71. self._input_shape = 2 ** (N * parallel) if all_onehot else (2 ** N) * parallel
  72. if encoder is None:
  73. self.encoder = tf.keras.Sequential()
  74. self.encoder.add(layers.Input(shape=(self._input_shape,)))
  75. self.encoder.add(layers.Dense(units=2 ** (N + 1)))
  76. self.encoder.add(LeakyReLU(alpha=0.001))
  77. # self.encoder.add(layers.Dropout(0.2))
  78. self.encoder.add(layers.Dense(units=2 ** (N + 1)))
  79. self.encoder.add(LeakyReLU(alpha=0.001))
  80. self.encoder.add(layers.Dense(units=signal_dim * parallel, activation="sigmoid"))
  81. # self.encoder.add(layers.ReLU(max_value=1.0))
  82. # self.encoder = quantize.quantize_model(self.encoder)
  83. else:
  84. self.encoder = encoder
  85. if decoder is None:
  86. self.decoder = tf.keras.Sequential()
  87. self.decoder.add(tf.keras.Input(shape=(signal_dim * parallel,)))
  88. self.decoder.add(layers.Dense(units=2 ** (N + 1)))
  89. # self.encoder.add(LeakyReLU(alpha=0.001))
  90. # self.decoder.add(layers.Dense(units=2 ** (N + 1)))
  91. # leaky relu with alpha=1 gives by far best results
  92. self.decoder.add(LeakyReLU(alpha=1))
  93. self.decoder.add(layers.Dense(units=self._input_shape, activation="softmax"))
  94. else:
  95. self.decoder = decoder
  96. # self.randomiser = tf.random_normal_initializer(mean=0.0, stddev=0.1, seed=None)
  97. self.mod = None
  98. self.demod = None
  99. self.compiled = False
  100. self.channel = tf.keras.Sequential()
  101. if self.bipolar:
  102. self.channel.add(custom_layers.ScaleAndOffset(2, -1, input_shape=(signal_dim * parallel,)))
  103. if isinstance(channel, int) or isinstance(channel, float):
  104. self.channel.add(custom_layers.AwgnChannel(noise_dB=channel, input_shape=(signal_dim * parallel,)))
  105. else:
  106. if not isinstance(channel, tf.keras.layers.Layer):
  107. raise ValueError("Channel is not a keras layer")
  108. self.channel.add(channel)
  109. self.data_generator = data_generator
  110. if data_generator is None:
  111. self.data_generator = BinaryOneHotGenerator
  112. self.cost = cost
  113. if cost is None:
  114. self.cost = losses.MeanSquaredError()
  115. # self.decoder.add(layers.Softmax(units=4, dtype=bool))
  116. # [
  117. # layers.Input(shape=(28, 28, 1)),
  118. # layers.Conv2D(16, (3, 3), activation='relu', padding='same', strides=2),
  119. # layers.Conv2D(8, (3, 3), activation='relu', padding='same', strides=2)
  120. # ])
  121. # self.decoder = tf.keras.Sequential([
  122. # layers.Conv2DTranspose(8, kernel_size=3, strides=2, activation='relu', padding='same'),
  123. # layers.Conv2DTranspose(16, kernel_size=3, strides=2, activation='relu', padding='same'),
  124. # layers.Conv2D(1, kernel_size=(3, 3), activation='sigmoid', padding='same')
  125. # ])
  126. @property
  127. def all_layers(self):
  128. return self.encoder.layers + self.decoder.layers #self.channel.layers +
  129. def call(self, x, **kwargs):
  130. y = self.encoder(x)
  131. z = self.channel(y)
  132. return self.decoder(z)
  133. def fit_encoder(self, modulation, sample_size, train_size=0.8, epochs=1, batch_size=1, shuffle=False):
  134. alphabet = basic.load_alphabet(modulation, polar=False)
  135. if not alphabet.shape[0] == self.N ** 2:
  136. raise Exception("Cardinality of modulation scheme is different from cardinality of autoencoder!")
  137. x_train = np.random.randint(self.N ** 2, size=int(sample_size * train_size))
  138. y_train = alphabet[x_train]
  139. x_train_ho = np.zeros((int(sample_size * train_size), self.N ** 2))
  140. for idx, x in np.ndenumerate(x_train):
  141. x_train_ho[idx, x] = 1
  142. x_test = np.random.randint(self.N ** 2, size=int(sample_size * (1 - train_size)))
  143. y_test = alphabet[x_test]
  144. x_test_ho = np.zeros((int(sample_size * (1 - train_size)), self.N ** 2))
  145. for idx, x in np.ndenumerate(x_test):
  146. x_test_ho[idx, x] = 1
  147. self.encoder.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())
  148. self.encoder.fit(x_train_ho, y_train,
  149. epochs=epochs,
  150. batch_size=batch_size,
  151. shuffle=shuffle,
  152. validation_data=(x_test_ho, y_test))
  153. def fit_decoder(self, modulation, samples):
  154. samples = int(samples * 1.3)
  155. demod = basic.AlphabetDemod(modulation, 0)
  156. x = np.random.rand(samples, 2) * 2 - 1
  157. x = x.reshape((-1, 2))
  158. f = np.zeros(x.shape[0])
  159. xf = np.c_[x[:, 0], x[:, 1], f]
  160. y = demod.forward(basic.RFSignal(misc.rect2polar(xf)))
  161. y_ho = misc.bit_matrix2one_hot(y.reshape((-1, 4)))
  162. X_train, X_test, y_train, y_test = train_test_split(x, y_ho)
  163. self.decoder.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())
  164. self.decoder.fit(X_train, y_train, shuffle=False, validation_data=(X_test, y_test))
  165. y_pred = self.decoder(X_test).numpy()
  166. y_pred2 = np.zeros(y_test.shape, dtype=bool)
  167. y_pred2[np.arange(y_pred2.shape[0]), np.argmax(y_pred, axis=1)] = True
  168. print("Decoder accuracy: %.4f" % accuracy_score(y_pred2, y_test))
  169. def train(self, epoch_size=3e3, epochs=5, callbacks=None, optimizer='adam', metrics=None):
  170. m = self.N * self.parallel
  171. x_train = self.data_generator(size=epoch_size, shape=m)
  172. x_test = self.data_generator(size=epoch_size*.3, shape=m)
  173. # test_samples = epoch_size
  174. # if test_samples % m:
  175. # test_samples += m - (test_samples % m)
  176. # x_test_array = misc.generate_random_bit_array(test_samples)
  177. # x_test = x_test_array.reshape((-1, m))
  178. # x_test_ho = misc.bit_matrix2one_hot(x_test)
  179. if not self.compiled:
  180. self.compile(
  181. optimizer=optimizer,
  182. loss=self.cost,
  183. metrics=metrics
  184. )
  185. self.compiled = True
  186. # self.build((self._input_shape, -1))
  187. # self.summary()
  188. history = self.fit(
  189. x_train, shuffle=False,
  190. validation_data=x_test, epochs=epochs,
  191. callbacks=callbacks,
  192. )
  193. return history
  194. # encoded_data = self.encoder(x_test_ho)
  195. # decoded_data = self.decoder(encoded_data).numpy()
  196. def get_modulator(self):
  197. if self.mod is None:
  198. self.mod = AutoencoderMod(self)
  199. return self.mod
  200. def get_demodulator(self):
  201. if self.demod is None:
  202. self.demod = AutoencoderDemod(self)
  203. return self.demod
  204. def view_encoder(encoder, N, samples=1000, title="Autoencoder generated alphabet"):
  205. test_values = misc.generate_random_bit_array(samples*N).reshape((-1, N))
  206. test_values_ho = misc.bit_matrix2one_hot(test_values)
  207. mvector = np.array([2 ** i for i in range(N)], dtype=int)
  208. symbols = (test_values * mvector).sum(axis=1)
  209. encoded = encoder(test_values_ho).numpy()
  210. if encoded.shape[1] == 1:
  211. encoded = np.c_[encoded, np.zeros(encoded.shape[0])]
  212. # encoded = misc.polar2rect(encoded)
  213. for i in range(2 ** N):
  214. xy = encoded[symbols == i]
  215. plt.plot(xy[:, 0], xy[:, 1], 'x', markersize=12, label=format(i, f'0{N}b'))
  216. plt.annotate(xy=[xy[:, 0].mean() + 0.01, xy[:, 1].mean() + 0.01], text=format(i, f'0{N}b'))
  217. plt.xlabel('Real')
  218. plt.ylabel('Imaginary')
  219. plt.title(title)
  220. # plt.legend()
  221. plt.show()
  222. pass
  223. if __name__ == '__main__':
  224. # (x_train, _), (x_test, _) = fashion_mnist.load_data()
  225. #
  226. # x_train = x_train.astype('float32') / 255.
  227. # x_test = x_test.astype('float32') / 255.
  228. #
  229. # print(f"Train data: {x_train.shape}")
  230. # print(f"Test data: {x_test.shape}")
  231. n = 4
  232. # samples = 1e6
  233. # x_train = misc.generate_random_bit_array(samples).reshape((-1, n))
  234. # x_train_ho = misc.bit_matrix2one_hot(x_train)
  235. # x_test_array = misc.generate_random_bit_array(samples * 0.3)
  236. # x_test = x_test_array.reshape((-1, n))
  237. # x_test_ho = misc.bit_matrix2one_hot(x_test)
  238. autoencoder = Autoencoder(n, -15)
  239. autoencoder.compile(optimizer='adam', loss=losses.MeanSquaredError())
  240. # autoencoder.fit_encoder(modulation='16qam',
  241. # sample_size=2e6,
  242. # train_size=0.8,
  243. # epochs=1,
  244. # batch_size=256,
  245. # shuffle=True)
  246. # view_encoder(autoencoder.encoder, n)
  247. # autoencoder.fit_decoder(modulation='16qam', samples=2e6)
  248. autoencoder.train()
  249. view_encoder(autoencoder.encoder, n)
  250. # view_encoder(autoencoder.encoder, n)
  251. # view_encoder(autoencoder.encoder, n)
  252. # autoencoder.compile(optimizer='adam', loss=losses.MeanSquaredError())
  253. #
  254. # autoencoder.fit(x_train_ho, x_train_ho,
  255. # epochs=1,
  256. # shuffle=False,
  257. # validation_data=(x_test_ho, x_test_ho))
  258. #
  259. # encoded_data = autoencoder.encoder(x_test_ho)
  260. # decoded_data = autoencoder.decoder(encoded_data).numpy()
  261. #
  262. # result = misc.int2bit_array(decoded_data.argmax(axis=1), n)
  263. # print("Accuracy: %.4f" % accuracy_score(x_test_array, result.reshape(-1, )))
  264. # view_encoder(autoencoder.encoder, n)
  265. pass