Bladeren bron

Custom cost function test

Min 4 jaren geleden
bovenliggende
commit
c2e69468bf
5 gewijzigde bestanden met toevoegingen van 186 en 49 verwijderingen
  1. 21 0
      graphs.py
  2. 9 1
      misc.py
  3. 31 6
      models/autoencoder.py
  4. 22 0
      models/data.py
  5. 103 42
      tests/min_test.py

+ 21 - 0
graphs.py

@@ -86,3 +86,24 @@ def get_SNR(mod, demod, ber_func=get_Optical_ber, samples=1000, start=-5, stop=1
     ber_x, ber_y = ber_func(mod, demod, samples, noise_start, noise_stop, **ber_kwargs)
     SNR = -ber_x + av_sig_pow
     return SNR, ber_y
+
+
def show_train_history(history, title="", save=None):
    """Plot the training curves recorded in a Keras ``History`` object.

    Draws whichever of loss / accuracy / val_loss / val_accuracy were
    recorded, optionally saves the figure to *save*, then shows it.
    """
    from matplotlib import pyplot as plt

    recorded = history.history
    xs = range(1, len(history.epoch) + 1)

    # (history key, legend label) — plotted in this fixed order.
    curves = (
        ('loss', 'Training Loss'),
        ('accuracy', 'Training Accuracy'),
        ('val_loss', 'Validation Loss'),
        ('val_accuracy', 'Validation Accuracy'),
    )
    for key, label in curves:
        if key in recorded:
            plt.plot(xs, recorded[key], label=label)

    plt.xlabel('Epochs')
    plt.ylabel('Loss/Accuracy' if 'accuracy' in recorded else 'Loss')
    plt.legend()
    plt.title(title)
    if save is not None:
        plt.savefig(save)
    plt.show()

+ 9 - 1
misc.py

@@ -2,7 +2,7 @@
 import numpy as np
 import math
 import matplotlib.pyplot as plt
-
+import pickle
 
 def display_alphabet(alphabet, values=None, a_vals=False, title="Alphabet constellation diagram"):
     rect = polar2rect(alphabet)
@@ -105,3 +105,11 @@ def generate_random_bit_array(size):
     return arr
 
 
def picke_save(obj, fname):
    """Serialize *obj* to the file *fname* with :mod:`pickle`.

    NOTE: the name is a historical typo of ``pickle_save``; it is kept so
    existing callers keep working — prefer the alias below in new code.
    """
    with open(fname, 'wb') as f:
        pickle.dump(obj, f)


# Correctly spelled, backward-compatible alias.
pickle_save = picke_save
+
+
def picke_load(fname):
    """Deserialize and return the object pickled in the file *fname*.

    NOTE: the name is a historical typo of ``pickle_load``; it is kept so
    existing callers keep working — prefer the alias below in new code.
    Only unpickle files you trust: pickle can execute arbitrary code.
    """
    with open(fname, 'rb') as f:
        return pickle.load(f)


# Correctly spelled, backward-compatible alias.
pickle_load = picke_load

+ 31 - 6
models/autoencoder.py

@@ -67,7 +67,16 @@ class AutoencoderDemod(defs.Demodulator):
 
 
 class Autoencoder(Model):
-    def __init__(self, N, channel, signal_dim=2, parallel=1, all_onehot=True, bipolar=True, encoder=None, decoder=None):
+    def __init__(self, N, channel,
+                 signal_dim=2,
+                 parallel=1,
+                 all_onehot=True,
+                 bipolar=True,
+                 encoder=None,
+                 decoder=None,
+                 data_generator=None,
+                 cost=None
+                 ):
         super(Autoencoder, self).__init__()
         self.N = N
         self.parallel = parallel
@@ -116,6 +125,13 @@ class Autoencoder(Model):
                 raise ValueError("Channel is not a keras layer")
             self.channel.add(channel)
 
+        self.data_generator = data_generator
+        if data_generator is None:
+            self.data_generator = BinaryOneHotGenerator
+
+        self.cost = cost
+        if cost is None:
+            self.cost = losses.MeanSquaredError()
         # self.decoder.add(layers.Softmax(units=4, dtype=bool))
 
         # [
@@ -181,10 +197,10 @@ class Autoencoder(Model):
 
         print("Decoder accuracy: %.4f" % accuracy_score(y_pred2, y_test))
 
-    def train(self, epoch_size=3e3, epochs=5):
+    def train(self, epoch_size=3e3, epochs=5, callbacks=None, optimizer='adam', metrics=None):
         m = self.N * self.parallel
-        x_train = BinaryOneHotGenerator(size=epoch_size, shape=m)
-        x_test = BinaryOneHotGenerator(size=epoch_size*.3, shape=m)
+        x_train = self.data_generator(size=epoch_size, shape=m)
+        x_test = self.data_generator(size=epoch_size*.3, shape=m)
 
         # test_samples = epoch_size
         # if test_samples % m:
@@ -194,12 +210,21 @@ class Autoencoder(Model):
         # x_test_ho = misc.bit_matrix2one_hot(x_test)
 
         if not self.compiled:
-            self.compile(optimizer='adam', loss=losses.MeanSquaredError())
+            self.compile(
+                optimizer=optimizer,
+                loss=self.cost,
+                metrics=metrics
+            )
             self.compiled = True
             # self.build((self._input_shape, -1))
             # self.summary()
 
-        self.fit(x_train, shuffle=False, validation_data=x_test, epochs=epochs)
+        history = self.fit(
+            x_train, shuffle=False,
+            validation_data=x_test, epochs=epochs,
+            callbacks=callbacks,
+        )
+        return history
         # encoded_data = self.encoder(x_test_ho)
         # decoded_data = self.decoder(encoded_data).numpy()
 

+ 22 - 0
models/data.py

@@ -30,3 +30,25 @@ class BinaryOneHotGenerator(Sequence):
 
     def __getitem__(self, idx):
         return self.x, self.x
+
+
class BinaryGenerator(Sequence):
    """Keras ``Sequence`` feeding a random bit matrix as both input and target
    (autoencoder-style), regenerated with fresh random bits every epoch."""

    def __init__(self, size=1e5, shape=2, dtype=tf.bool):
        # Round `size` up to the next multiple of `shape` so the flat bit
        # array reshapes cleanly into (-1, shape) rows.
        size = int(size)
        if size % shape:
            size += shape - (size % shape)
        self.size = size    # total number of bits drawn per epoch
        self.shape = shape  # bits per sample (row width)
        self.x = None       # current epoch's tensor; filled by on_epoch_end()
        self.dtype = dtype
        self.on_epoch_end()  # generate the first epoch's data eagerly

    def on_epoch_end(self):
        # Draw a fresh (size/shape, shape) random bit matrix for the next epoch.
        x = misc.generate_random_bit_array(self.size).reshape((-1, self.shape))
        self.x = tf.convert_to_tensor(x, dtype=self.dtype)

    def __len__(self):
        # NOTE(review): Sequence.__len__ should be the number of batches, but
        # __getitem__ below returns the entire data set for every index — as
        # written one "epoch" iterates `size` identical full-data batches.
        # This probably should return 1 (the sibling BinaryOneHotGenerator
        # appears to follow the same pattern) — confirm before changing.
        return self.size

    def __getitem__(self, idx):
        # Autoencoder training: the target equals the input.
        return self.x, self.x

+ 103 - 42
tests/min_test.py

@@ -8,9 +8,10 @@ from itertools import chain
 from sys import stdout
 
 from tensorflow.python.framework.errors_impl import NotFoundError
+from tensorflow.keras import backend as K
 
 import defs
-from graphs import get_SNR, get_AWGN_ber
+from graphs import get_SNR, get_AWGN_ber, show_train_history
 from models import basic
 from models.autoencoder import Autoencoder, view_encoder
 import matplotlib.pyplot as plt
@@ -19,13 +20,12 @@ import misc
 import numpy as np
 
 from models.basic import AlphabetDemod, AlphabetMod
+from models.data import BinaryGenerator
+from models.layers import BitsToSymbols, SymbolsToBits
 from models.optical_channel import OpticalChannel
 from models.quantized_net import QuantizedNeuralNetwork
 
 
-
-
-
 def _test_optics_autoencoder():
     ch = OpticalChannel(
         noise_level=-10,
@@ -71,7 +71,6 @@ def _test_optics_autoencoder():
     plt.show()
 
 
-
 def _test_autoencoder_pretrain():
     # aenc = Autoencoder(4, -25)
     # aenc.train(samples=1e6)
@@ -232,9 +231,10 @@ def _test_autoencoder_perf():
     aenc.train(epoch_size=1e3, epochs=10)
     # #
     m = aenc.N * aenc.parallel
-    x_train = misc.bit_matrix2one_hot(misc.generate_random_bit_array(100*m).reshape((-1, m)))
+    x_train = misc.bit_matrix2one_hot(misc.generate_random_bit_array(100 * m).reshape((-1, m)))
     x_train_enc = aenc.encoder(x_train)
     x_train = tf.cast(x_train, tf.float32)
+
     #
     # plt.plot(*get_SNR(
     #     aenc.get_modulator(),
@@ -261,6 +261,7 @@ def _test_autoencoder_perf():
             def representative_data_gen():
                 for input_value in tf.data.Dataset.from_tensor_slices(train_x).batch(1).take(100):
                     yield [input_value]
+
             converter.representative_dataset = representative_data_gen
         tflite_model = converter.convert()
         tflite_models_dir = pathlib.Path("/tmp/tflite/")
@@ -394,23 +395,29 @@ def _test_autoencoder_perf():
 def _test_autoencoder_perf2():
     aenc = Autoencoder(2, -20)
     aenc.train(samples=3e6)
-    plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15), '-', label='2Bit AE')
+    plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50,
+                      start=-5, stop=15), '-', label='2Bit AE')
 
     aenc = Autoencoder(3, -20)
     aenc.train(samples=3e6)
-    plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15), '-', label='3Bit AE')
+    plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50,
+                      start=-5, stop=15), '-', label='3Bit AE')
 
     aenc = Autoencoder(4, -20)
     aenc.train(samples=3e6)
-    plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15), '-', label='4Bit AE')
+    plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50,
+                      start=-5, stop=15), '-', label='4Bit AE')
 
     aenc = Autoencoder(5, -20)
     aenc.train(samples=3e6)
-    plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15), '-', label='5Bit AE')
+    plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50,
+                      start=-5, stop=15), '-', label='5Bit AE')
 
     for a in ['qpsk', '8psk', '16qam', '32qam', '64qam']:
         try:
-            plt.plot(*get_SNR(AlphabetMod(a, 10e6), AlphabetDemod(a, 10e6), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15,), '-', label=a.upper())
+            plt.plot(
+                *get_SNR(AlphabetMod(a, 10e6), AlphabetDemod(a, 10e6), ber_func=get_AWGN_ber, samples=100000, steps=50,
+                         start=-5, stop=15, ), '-', label=a.upper())
         except KeyboardInterrupt:
             break
         except Exception:
@@ -455,8 +462,8 @@ def _test_autoencoder_perf_qnn():
     view_encoder(aenc.encoder, 4, title="FP32 Alphabet")
 
     batch_size = 25000
-    x_train = misc.bit_matrix2one_hot(misc.generate_random_bit_array(batch_size*m).reshape((-1, m)))
-    x_test = misc.bit_matrix2one_hot(misc.generate_random_bit_array(5000*m).reshape((-1, m)))
+    x_train = misc.bit_matrix2one_hot(misc.generate_random_bit_array(batch_size * m).reshape((-1, m)))
+    x_test = misc.bit_matrix2one_hot(misc.generate_random_bit_array(5000 * m).reshape((-1, m)))
     bits = [np.log2(i) for i in (32,)][0]
     alphabet_scalars = 2  # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
     num_layers = sum([layer.__class__.__name__ in ('Dense',) for layer in aenc.all_layers])
@@ -479,45 +486,99 @@ def _test_autoencoder_perf_qnn():
     qnn.quantized_net.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())
     view_encoder(qnn.quantized_net, 4, title=f"Quantised {16}b alphabet")
 
+    # view_encoder(qnn_enc.quantized_net, 4, title=f"Quantised {b}b alphabet")
+    # _, q_accuracy = qnn.quantized_net.evaluate(x_test, x_test, verbose=True)
+    pass
 
 
class BitAwareAutoencoder(Autoencoder):
    """Autoencoder with a custom loss that adds a penalty on the distance
    between the normalised argmax positions of the true and predicted
    vectors, on top of the plain RMSE of the raw outputs."""

    def __init__(self, N, channel, **kwargs):
        # The bound method `self.cost` is handed to the base class as the
        # training loss. NOTE(review): if a caller also passes cost= inside
        # **kwargs this raises a duplicate-keyword TypeError; and self.BITS
        # is assigned only after super().__init__() returns — harmless since
        # the loss is not invoked during construction, but fragile.
        super().__init__(N, channel, **kwargs, cost=self.cost)
        self.BITS = 2 ** N - 1  # largest symbol index; normalises argmax to [0, 1]
        #  data_generator=BinaryGenerator,
        # self.b2s_layer = BitsToSymbols(2**N)
        # self.s2b_layer = SymbolsToBits(2**N)

    def cost(self, y_true, y_pred):
        """Custom loss: RMSE(y_true, y_pred) + 300 * RMSE of the normalised
        argmax positions of the two vectors."""
        y = tf.cast(y_true, dtype=tf.float32)
        # NOTE(review): tf.math.argmax without axis= reduces over axis 0
        # (the batch dimension), not per sample — confirm axis=-1 was not
        # intended. Also argmax(y_pred) has no gradient w.r.t. y_pred, so
        # the second term cannot drive learning directly.
        z0 = tf.math.argmax(y) / self.BITS
        z1 = tf.math.argmax(y_pred) / self.BITS
        error0 = y - y_pred
        sqr_error0 = K.square(error0)  # squared error
        mean_sqr_error0 = K.mean(sqr_error0)  # mean squared error
        sme0 = K.sqrt(mean_sqr_error0)  # RMSE of the raw output

        error1 = z0 - z1
        sqr_error1 = K.square(error1)
        mean_sqr_error1 = K.mean(sqr_error1)
        sme1 = K.sqrt(mean_sqr_error1)  # RMSE of normalised symbol positions
        # 300 is an empirically chosen weight for the symbol-position term.
        return sme0 + tf.cast(sme1 * 300, dtype=tf.float32)

    # def call(self, x, **kwargs):
    #     x1 = self.b2s_layer(x)
    #     y = self.encoder(x1)
    #     z = self.channel(y)
    #     z1 = self.decoder(z)
    #     return self.s2b_layer(z1)
+
+
def _bit_aware_test():
    """Train/evaluate the bit-aware autoencoder (N=6) and compare its
    BER-vs-SNR curve against a classical 64-QAM modem over AWGN."""
    aenc = BitAwareAutoencoder(6, -50, bipolar=True)

    # Restore a previous checkpoint if one exists; otherwise start fresh.
    try:
        aenc.load_weights('ae_bitaware')
    except NotFoundError:
        pass

    # try:
    #     hist = aenc.train(
    #         epochs=70,
    #         epoch_size=1e3,
    #         optimizer='adam',
    #         # metrics=[tf.keras.metrics.Accuracy()],
    #         # callbacks=[tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3, min_delta=0.001)]
    #     )
    #     show_train_history(hist, "Autoencoder training history")
    # except KeyboardInterrupt:
    #     aenc.save_weights('ae_bitaware')
    #     exit(0)
    #
    # aenc.save_weights('ae_bitaware')

    # N=6 -> 6-bit symbols (title previously said "4bit").
    view_encoder(aenc.encoder, 6, title="6bit autoencoder alphabet")

    print("Computing BER/SNR for autoencoder")
    plt.plot(*get_SNR(
        aenc.get_modulator(),
        aenc.get_demodulator(),
        ber_func=get_AWGN_ber,
        samples=1000000, steps=40,
        start=-5, stop=15), '-', label='6Bit AE')

    # Reference curve uses a 64-QAM alphabet (labels previously said 16QAM).
    print("Computing BER/SNR for 64QAM")
    plt.plot(*get_SNR(
        AlphabetMod('64qam', 10e6),
        AlphabetDemod('64qam', 10e6),
        ber_func=get_AWGN_ber,
        samples=1000000,
        steps=40,
        start=-5,
        stop=15,
    ), '-', label='64qam AWGN')

    plt.yscale('log')
    plt.grid()
    plt.xlabel('SNR dB')
    plt.ylabel('BER')
    plt.title("64QAM vs autoencoder")
    plt.show()
 
-    # plt.plot(*get_SNR(
-    #     AlphabetMod('16qam', 10e6),
-    #     AlphabetDemod('16qam', 10e6),
-    #     ber_func=get_AWGN_ber,
-    #     samples=100000,
-    #     steps=50,
-    #     start=-5,
-    #     stop=15,
-    # ), '-', label='16qam AWGN')
-    #
-    # plt.plot(*get_SNR(
-    #     AlphabetMod('16qam', 10e6),
-    #     AlphabetDemod('16qam', 10e6),
-    #     samples=100000,
-    #     steps=50,
-    #     start=-5,
-    #     stop=15,
-    # ), '-', label='16qam OPTICAL')
-    #
-    # plt.yscale('log')
-    # plt.grid()
-    # plt.xlabel('SNR dB')
-    # plt.show()
+
# Script entry point: runs only the bit-aware autoencoder experiment;
# the remaining experiments are kept for reference but disabled.
if __name__ == '__main__':
    _bit_aware_test()

    # _test_autoencoder_perf()
    # _test_autoencoder_perf_qnn()
    # _test_autoencoder_perf2()
    # _test_autoencoder_pretrain()
    # _test_optics_autoencoder()
-