|
|
@@ -8,9 +8,10 @@ from itertools import chain
|
|
|
from sys import stdout
|
|
|
|
|
|
from tensorflow.python.framework.errors_impl import NotFoundError
|
|
|
+from tensorflow.keras import backend as K
|
|
|
|
|
|
import defs
|
|
|
-from graphs import get_SNR, get_AWGN_ber
|
|
|
+from graphs import get_SNR, get_AWGN_ber, show_train_history
|
|
|
from models import basic
|
|
|
from models.autoencoder import Autoencoder, view_encoder
|
|
|
import matplotlib.pyplot as plt
|
|
|
@@ -19,13 +20,12 @@ import misc
|
|
|
import numpy as np
|
|
|
|
|
|
from models.basic import AlphabetDemod, AlphabetMod
|
|
|
+from models.data import BinaryGenerator
|
|
|
+from models.layers import BitsToSymbols, SymbolsToBits
|
|
|
from models.optical_channel import OpticalChannel
|
|
|
from models.quantized_net import QuantizedNeuralNetwork
|
|
|
|
|
|
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
def _test_optics_autoencoder():
|
|
|
ch = OpticalChannel(
|
|
|
noise_level=-10,
|
|
|
@@ -71,7 +71,6 @@ def _test_optics_autoencoder():
|
|
|
plt.show()
|
|
|
|
|
|
|
|
|
-
|
|
|
def _test_autoencoder_pretrain():
|
|
|
# aenc = Autoencoder(4, -25)
|
|
|
# aenc.train(samples=1e6)
|
|
|
@@ -232,9 +231,10 @@ def _test_autoencoder_perf():
|
|
|
aenc.train(epoch_size=1e3, epochs=10)
|
|
|
# #
|
|
|
m = aenc.N * aenc.parallel
|
|
|
- x_train = misc.bit_matrix2one_hot(misc.generate_random_bit_array(100*m).reshape((-1, m)))
|
|
|
+ x_train = misc.bit_matrix2one_hot(misc.generate_random_bit_array(100 * m).reshape((-1, m)))
|
|
|
x_train_enc = aenc.encoder(x_train)
|
|
|
x_train = tf.cast(x_train, tf.float32)
|
|
|
+
|
|
|
#
|
|
|
# plt.plot(*get_SNR(
|
|
|
# aenc.get_modulator(),
|
|
|
@@ -261,6 +261,7 @@ def _test_autoencoder_perf():
|
|
|
def representative_data_gen():
|
|
|
for input_value in tf.data.Dataset.from_tensor_slices(train_x).batch(1).take(100):
|
|
|
yield [input_value]
|
|
|
+
|
|
|
converter.representative_dataset = representative_data_gen
|
|
|
tflite_model = converter.convert()
|
|
|
tflite_models_dir = pathlib.Path("/tmp/tflite/")
|
|
|
@@ -394,23 +395,29 @@ def _test_autoencoder_perf():
|
|
|
def _test_autoencoder_perf2():
|
|
|
aenc = Autoencoder(2, -20)
|
|
|
aenc.train(samples=3e6)
|
|
|
- plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15), '-', label='2Bit AE')
|
|
|
+ plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50,
|
|
|
+ start=-5, stop=15), '-', label='2Bit AE')
|
|
|
|
|
|
aenc = Autoencoder(3, -20)
|
|
|
aenc.train(samples=3e6)
|
|
|
- plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15), '-', label='3Bit AE')
|
|
|
+ plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50,
|
|
|
+ start=-5, stop=15), '-', label='3Bit AE')
|
|
|
|
|
|
aenc = Autoencoder(4, -20)
|
|
|
aenc.train(samples=3e6)
|
|
|
- plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15), '-', label='4Bit AE')
|
|
|
+ plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50,
|
|
|
+ start=-5, stop=15), '-', label='4Bit AE')
|
|
|
|
|
|
aenc = Autoencoder(5, -20)
|
|
|
aenc.train(samples=3e6)
|
|
|
- plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15), '-', label='5Bit AE')
|
|
|
+ plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50,
|
|
|
+ start=-5, stop=15), '-', label='5Bit AE')
|
|
|
|
|
|
for a in ['qpsk', '8psk', '16qam', '32qam', '64qam']:
|
|
|
try:
|
|
|
- plt.plot(*get_SNR(AlphabetMod(a, 10e6), AlphabetDemod(a, 10e6), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15,), '-', label=a.upper())
|
|
|
+ plt.plot(
|
|
|
+ *get_SNR(AlphabetMod(a, 10e6), AlphabetDemod(a, 10e6), ber_func=get_AWGN_ber, samples=100000, steps=50,
|
|
|
+ start=-5, stop=15, ), '-', label=a.upper())
|
|
|
except KeyboardInterrupt:
|
|
|
break
|
|
|
except Exception:
|
|
|
@@ -455,8 +462,8 @@ def _test_autoencoder_perf_qnn():
|
|
|
view_encoder(aenc.encoder, 4, title="FP32 Alphabet")
|
|
|
|
|
|
batch_size = 25000
|
|
|
- x_train = misc.bit_matrix2one_hot(misc.generate_random_bit_array(batch_size*m).reshape((-1, m)))
|
|
|
- x_test = misc.bit_matrix2one_hot(misc.generate_random_bit_array(5000*m).reshape((-1, m)))
|
|
|
+ x_train = misc.bit_matrix2one_hot(misc.generate_random_bit_array(batch_size * m).reshape((-1, m)))
|
|
|
+ x_test = misc.bit_matrix2one_hot(misc.generate_random_bit_array(5000 * m).reshape((-1, m)))
|
|
|
bits = [np.log2(i) for i in (32,)][0]
|
|
|
alphabet_scalars = 2 # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
|
|
num_layers = sum([layer.__class__.__name__ in ('Dense',) for layer in aenc.all_layers])
|
|
|
@@ -479,45 +486,99 @@ def _test_autoencoder_perf_qnn():
|
|
|
qnn.quantized_net.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())
|
|
|
view_encoder(qnn.quantized_net, 4, title=f"Quantised {16}b alphabet")
|
|
|
|
|
|
+ # view_encoder(qnn_enc.quantized_net, 4, title=f"Quantised {b}b alphabet")
|
|
|
+ # _, q_accuracy = qnn.quantized_net.evaluate(x_test, x_test, verbose=True)
|
|
|
+ pass
|
|
|
|
|
|
|
|
|
+class BitAwareAutoencoder(Autoencoder):
|
|
|
+ def __init__(self, N, channel, **kwargs):
|
|
|
+ super().__init__(N, channel, **kwargs, cost=self.cost)
|
|
|
+ self.BITS = 2 ** N - 1
|
|
|
+ # data_generator=BinaryGenerator,
|
|
|
+ # self.b2s_layer = BitsToSymbols(2**N)
|
|
|
+ # self.s2b_layer = SymbolsToBits(2**N)
|
|
|
+
|
|
|
+ def cost(self, y_true, y_pred):
|
|
|
+ y = tf.cast(y_true, dtype=tf.float32)
|
|
|
+ z0 = tf.math.argmax(y) / self.BITS
|
|
|
+ z1 = tf.math.argmax(y_pred) / self.BITS
|
|
|
+ error0 = y - y_pred
|
|
|
+ sqr_error0 = K.square(error0) # mean of the square of the error
|
|
|
+ mean_sqr_error0 = K.mean(sqr_error0) # square root of the mean of the square of the error
|
|
|
+ sme0 = K.sqrt(mean_sqr_error0) # return the error
|
|
|
+
|
|
|
+ error1 = z0 - z1
|
|
|
+ sqr_error1 = K.square(error1)
|
|
|
+ mean_sqr_error1 = K.mean(sqr_error1)
|
|
|
+ sme1 = K.sqrt(mean_sqr_error1)
|
|
|
+ return sme0 + tf.cast(sme1 * 300, dtype=tf.float32)
|
|
|
+
|
|
|
+ # def call(self, x, **kwargs):
|
|
|
+ # x1 = self.b2s_layer(x)
|
|
|
+ # y = self.encoder(x1)
|
|
|
+ # z = self.channel(y)
|
|
|
+ # z1 = self.decoder(z)
|
|
|
+ # return self.s2b_layer(z1)
|
|
|
+
|
|
|
+
|
|
|
+def _bit_aware_test():
|
|
|
+ aenc = BitAwareAutoencoder(6, -50, bipolar=True)
|
|
|
|
|
|
- # view_encoder(qnn_enc.quantized_net, 4, title=f"Quantised {b}b alphabet")
|
|
|
+ try:
|
|
|
+ aenc.load_weights('ae_bitaware')
|
|
|
+ except NotFoundError:
|
|
|
+ pass
|
|
|
|
|
|
+ # try:
|
|
|
+ # hist = aenc.train(
|
|
|
+ # epochs=70,
|
|
|
+ # epoch_size=1e3,
|
|
|
+ # optimizer='adam',
|
|
|
+ # # metrics=[tf.keras.metrics.Accuracy()],
|
|
|
+ # # callbacks=[tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3, min_delta=0.001)]
|
|
|
+ # )
|
|
|
+ # show_train_history(hist, "Autonecoder training history")
|
|
|
+ # except KeyboardInterrupt:
|
|
|
+ # aenc.save_weights('ae_bitaware')
|
|
|
+ # exit(0)
|
|
|
+ #
|
|
|
+ # aenc.save_weights('ae_bitaware')
|
|
|
|
|
|
- # _, q_accuracy = qnn.quantized_net.evaluate(x_test, x_test, verbose=True)
|
|
|
- pass
|
|
|
+ view_encoder(aenc.encoder, 6, title=f"4bit autoencoder alphabet")
|
|
|
|
|
|
-if __name__ == '__main__':
|
|
|
+ print("Computing BER/SNR for autoencoder")
|
|
|
+ plt.plot(*get_SNR(
|
|
|
+ aenc.get_modulator(),
|
|
|
+ aenc.get_demodulator(),
|
|
|
+ ber_func=get_AWGN_ber,
|
|
|
+ samples=1000000, steps=40,
|
|
|
+ start=-5, stop=15), '-', label='4Bit AE')
|
|
|
|
|
|
+ print("Computing BER/SNR for QAM16")
|
|
|
+ plt.plot(*get_SNR(
|
|
|
+ AlphabetMod('64qam', 10e6),
|
|
|
+ AlphabetDemod('64qam', 10e6),
|
|
|
+ ber_func=get_AWGN_ber,
|
|
|
+ samples=1000000,
|
|
|
+ steps=40,
|
|
|
+ start=-5,
|
|
|
+ stop=15,
|
|
|
+ ), '-', label='16qam AWGN')
|
|
|
+
|
|
|
+ plt.yscale('log')
|
|
|
+ plt.grid()
|
|
|
+ plt.xlabel('SNR dB')
|
|
|
+ plt.ylabel('BER')
|
|
|
+ plt.title("16QAM vs autoencoder")
|
|
|
+ plt.show()
|
|
|
|
|
|
- # plt.plot(*get_SNR(
|
|
|
- # AlphabetMod('16qam', 10e6),
|
|
|
- # AlphabetDemod('16qam', 10e6),
|
|
|
- # ber_func=get_AWGN_ber,
|
|
|
- # samples=100000,
|
|
|
- # steps=50,
|
|
|
- # start=-5,
|
|
|
- # stop=15,
|
|
|
- # ), '-', label='16qam AWGN')
|
|
|
- #
|
|
|
- # plt.plot(*get_SNR(
|
|
|
- # AlphabetMod('16qam', 10e6),
|
|
|
- # AlphabetDemod('16qam', 10e6),
|
|
|
- # samples=100000,
|
|
|
- # steps=50,
|
|
|
- # start=-5,
|
|
|
- # stop=15,
|
|
|
- # ), '-', label='16qam OPTICAL')
|
|
|
- #
|
|
|
- # plt.yscale('log')
|
|
|
- # plt.grid()
|
|
|
- # plt.xlabel('SNR dB')
|
|
|
- # plt.show()
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ _bit_aware_test()
|
|
|
|
|
|
# _test_autoencoder_perf()
|
|
|
- _test_autoencoder_perf_qnn()
|
|
|
+ # _test_autoencoder_perf_qnn()
|
|
|
# _test_autoencoder_perf2()
|
|
|
# _test_autoencoder_pretrain()
|
|
|
# _test_optics_autoencoder()
|
|
|
-
|