@@ -1,143 +1,14 @@
+import json
 import math
-
+import os
+from datetime import datetime as dt
 import tensorflow as tf
 import numpy as np
 import matplotlib.pyplot as plt
+from sklearn.metrics import accuracy_score
 from sklearn.preprocessing import OneHotEncoder
 from tensorflow.keras import layers, losses
-
-
-class ExtractCentralMessage(layers.Layer):
-    def __init__(self, messages_per_block, samples_per_symbol):
-        """
-        A keras layer that extracts the central message(symbol) in a block.
-
-        :param messages_per_block: Total number of messages in transmission block
-        :param samples_per_symbol: Number of samples per transmitted symbol
-        """
-        super(ExtractCentralMessage, self).__init__()
-
-        temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
-        i = np.identity(samples_per_symbol)
-        begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
-        end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
-        temp_w[begin:end, :] = i
-
-        self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)
-
-    def call(self, inputs, **kwargs):
-        return tf.matmul(inputs, self.w)
-
-
-class AwgnChannel(layers.Layer):
-    def __init__(self, rx_stddev=0.1):
-        """
-        A additive white gaussian noise channel model. The GaussianNoise class is utilized to prevent identical noise
-        being applied every time the call function is called.
-
-        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
-        """
-        super(AwgnChannel, self).__init__()
-        self.noise_layer = layers.GaussianNoise(rx_stddev)
-
-    def call(self, inputs, **kwargs):
-        return self.noise_layer.call(inputs, training=True)
-
-
-class DigitizationLayer(layers.Layer):
-    def __init__(self,
-                 fs,
-                 num_of_samples,
-                 lpf_cutoff=32e9,
-                 q_stddev=0.1):
-        """
-        This layer simulated the finite bandwidth of the hardware by means of a low pass filter. In addition to this,
-        artefacts casued by quantization is modelled by the addition of white gaussian noise of a given stddev.
-
-        :param fs: Sampling frequency of the simulation in Hz
-        :param num_of_samples: Total number of samples in the input
-        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
-        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
-        """
-        super(DigitizationLayer, self).__init__()
-
-        self.noise_layer = layers.GaussianNoise(q_stddev)
-        freq = np.fft.fftfreq(num_of_samples, d=1/fs)
-        temp = np.ones(freq.shape)
-
-        for idx, val in np.ndenumerate(freq):
-            if np.abs(val) > lpf_cutoff:
-                temp[idx] = 0
-
-        self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)
-
-    def call(self, inputs, **kwargs):
-        complex_in = tf.cast(inputs, dtype=tf.complex64)
-        val_f = tf.signal.fft(complex_in)
-        filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
-        filtered_t = tf.signal.ifft(filtered_f)
-        real_t = tf.cast(filtered_t, dtype=tf.float32)
-        noisy = self.noise_layer.call(real_t, training=True)
-        return noisy
-
-
-class OpticalChannel(layers.Layer):
-    def __init__(self,
-                 fs,
-                 num_of_samples,
-                 dispersion_factor,
-                 fiber_length,
-                 lpf_cutoff=32e9,
-                 rx_stddev=0.01,
-                 q_stddev=0.01):
-        """
-        A channel model that simulates chromatic dispersion, non-linear photodiode detection, finite bandwidth of
-        ADC/DAC as well as additive white gaussian noise in optical communication channels.
-
-        :param fs: Sampling frequency of the simulation in Hz
-        :param num_of_samples: Total number of samples in the input
-        :param dispersion_factor: Dispersion factor in s^2/km
-        :param fiber_length: Length of fiber to model in km
-        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
-        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
-        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
-        """
-        super(OpticalChannel, self).__init__()
-
-        self.noise_layer = layers.GaussianNoise(rx_stddev)
-        self.digitization_layer = DigitizationLayer(fs=fs,
-                                                    num_of_samples=num_of_samples,
-                                                    lpf_cutoff=lpf_cutoff,
-                                                    q_stddev=q_stddev)
-        self.flatten_layer = layers.Flatten()
-
-        self.fs = fs
-        self.freq = tf.convert_to_tensor(np.fft.fftfreq(num_of_samples, d=1/fs), dtype=tf.complex128)
-        self.multiplier = tf.math.exp(0.5j*dispersion_factor*fiber_length*tf.math.square(2*math.pi*self.freq))
-
-    def call(self, inputs, **kwargs):
-        # DAC LPF and noise
-        dac_out = self.digitization_layer(inputs)
-
-        # Chromatic Dispersion
-        complex_val = tf.cast(dac_out, dtype=tf.complex128)
-        val_f = tf.signal.fft(complex_val)
-        disp_f = tf.math.multiply(val_f, self.multiplier)
-        disp_t = tf.signal.ifft(disp_f)
-
-        # Squared-Law Detection
-        pd_out = tf.square(tf.abs(disp_t))
-
-        # Casting back to floatx
-        real_val = tf.cast(pd_out, dtype=tf.float32)
-
-        # Adding photo-diode receiver noise
-        rx_signal = self.noise_layer.call(real_val, training=True)
-
-        # ADC LPF and noise
-        adc_out = self.digitization_layer(rx_signal)
-
-        return adc_out
+from models.custom_layers import ExtractCentralMessage, OpticalChannel, DigitizationLayer, SymbolsToBits
 
 
 class EndToEndAutoencoder(tf.keras.Model):
@@ -145,7 +16,8 @@ class EndToEndAutoencoder(tf.keras.Model):
                  cardinality,
                  samples_per_symbol,
                  messages_per_block,
-                 channel):
+                 channel,
+                 custom_loss_fn=False):
         """
         The autoencoder that aims to find a encoding of the input messages. It should be noted that a "block" consists
         of multiple "messages" to introduce memory into the simulation as this is essential for modelling inter-symbol
@@ -160,38 +32,158 @@ class EndToEndAutoencoder(tf.keras.Model):
         # Labelled M in paper
         self.cardinality = cardinality
+        self.bits_per_symbol = int(math.log(self.cardinality, 2))
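+        # e.g. cardinality M = 32 gives log2(32) = 5 bits carried per symbol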
+
         # Labelled n in paper
         self.samples_per_symbol = samples_per_symbol
-        # Labelled N in paper
+
+        # Labelled N in paper - incremented if even to ensure an odd value
        if messages_per_block % 2 == 0:
             messages_per_block += 1
         self.messages_per_block = messages_per_block
+
         # Channel Model Layer
         if isinstance(channel, layers.Layer):
             self.channel = tf.keras.Sequential([
                 layers.Flatten(),
                 channel,
                 ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
-        ])
+        ], name="channel_model")
         else:
-            raise TypeError("Channel must be a subclass of keras.layers.layer!")
+            raise TypeError("Channel must be a subclass of \"tensorflow.keras.layers.Layer\"!")
+
+        # Boolean identifying whether the bit mapping is to be learnt
+        self.custom_loss_fn = custom_loss_fn
+
+        # Other parameters/metrics
+        self.symbol_error_rate = None
+        self.bit_error_rate = None
+        self.snr = 20 * math.log(0.5 / channel.rx_stddev, 10)
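+        # SNR in dB, assuming a mean received signal level of 0.5: SNR = 20*log10(0.5/rx_stddev),
+        # e.g. rx_stddev = 0.01 gives roughly 34 dB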
+
+        # Model Hyper-parameters
+        leaky_relu_alpha = 0
+        relu_clip_val = 1.0
 
         # Encoding Neural Network
         self.encoder = tf.keras.Sequential([
             layers.Input(shape=(self.messages_per_block, self.cardinality)),
-            layers.Dense(2 * self.cardinality, activation='relu'),
-            layers.Dense(2 * self.cardinality, activation='relu'),
-            layers.Dense(self.samples_per_symbol),
-            layers.ReLU(max_value=1.0)
-        ])
+            layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
+            layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
+            layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
+            layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
+            layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
+            # layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
+            # layers.TimeDistributed(layers.ReLU(max_value=relu_clip_val))
+        ], name="encoding_model")
 
         # Decoding Neural Network
         self.decoder = tf.keras.Sequential([
-            layers.Dense(self.samples_per_symbol, activation='relu'),
-            layers.Dense(2 * self.cardinality, activation='relu'),
-            layers.Dense(2 * self.cardinality, activation='relu'),
+            layers.Dense(2 * self.cardinality),
+            layers.LeakyReLU(alpha=leaky_relu_alpha),
+            layers.Dense(2 * self.cardinality),
+            layers.LeakyReLU(alpha=leaky_relu_alpha),
             layers.Dense(self.cardinality, activation='softmax')
-        ])
+        ], name="decoding_model")
+
+    def save_end_to_end(self):
+        # extract all params and save
+
+        params = {"fs": self.channel.layers[1].fs,
+                  "cardinality": self.cardinality,
+                  "samples_per_symbol": self.samples_per_symbol,
+                  "messages_per_block": self.messages_per_block,
+                  "dispersion_factor": self.channel.layers[1].dispersion_factor,
+                  "fiber_length": float(self.channel.layers[1].fiber_length),
+                  "fiber_length_stddev": float(self.channel.layers[1].fiber_length_stddev),
+                  "lpf_cutoff": self.channel.layers[1].lpf_cutoff,
+                  "rx_stddev": self.channel.layers[1].rx_stddev,
+                  "sig_avg": self.channel.layers[1].sig_avg,
+                  "enob": self.channel.layers[1].enob,
+                  "custom_loss_fn": self.custom_loss_fn
+                  }
+        dir_str = os.path.join("exports", dt.utcnow().strftime("%Y%m%d-%H%M%S"))
+
+        if not os.path.exists(dir_str):
+            os.makedirs(dir_str)
+
+        with open(os.path.join(dir_str, 'params.json'), 'w') as outfile:
+            json.dump(params, outfile)
+
+        ################################################################################################################
+        # This section exports the weights of the encoder and decoder formatted using Python variable instantiation
+        # syntax
+        ################################################################################################################
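+        # The exported enc_weights.py has the (illustrative) form:
+        #   enc_weights = [[[...], ...], ...]   # one kernel matrix per Dense layer
+        #   enc_bias = [[...], ...]             # one bias vector per Dense layer
+        # and dec_weights.py follows the same pattern.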
+
+        enc_weights, dec_weights = self.extract_weights()
+
+        enc_weights = [x.tolist() for x in enc_weights]
+        dec_weights = [x.tolist() for x in dec_weights]
+
+        enc_w = enc_weights[::2]
+        enc_b = enc_weights[1::2]
+
+        dec_w = dec_weights[::2]
+        dec_b = dec_weights[1::2]
+
+        with open(os.path.join(dir_str, 'enc_weights.py'), 'w') as outfile:
+            outfile.write("enc_weights = ")
+            outfile.write(str(enc_w))
+            outfile.write("\n\nenc_bias = ")
+            outfile.write(str(enc_b))
+
+        with open(os.path.join(dir_str, 'dec_weights.py'), 'w') as outfile:
+            outfile.write("dec_weights = ")
+            outfile.write(str(dec_w))
+            outfile.write("\n\ndec_bias = ")
+            outfile.write(str(dec_b))
+
+        ################################################################################################################
+
+        self.encoder.save(os.path.join(dir_str, 'encoder'))
+        self.decoder.save(os.path.join(dir_str, 'decoder'))
+
+    def extract_weights(self):
+        enc_weights = self.encoder.get_weights()
+        dec_weights = self.decoder.get_weights()
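+        # get_weights() returns [kernel_0, bias_0, kernel_1, bias_1, ...] for a stack of Dense
+        # layers, which is what the [::2]/[1::2] slicing in save_end_to_end() relies on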
+
+        return enc_weights, dec_weights
+
+    def encode_stream(self, x):
+        enc_weights, _ = self.extract_weights()
+
+        for i in range(len(enc_weights) // 2):
+            x = np.matmul(x, enc_weights[2 * i]) + enc_weights[2 * i + 1]
+
+            if i == len(enc_weights) // 2 - 1:
+                x = tf.keras.activations.sigmoid(x).numpy()
+            else:
+                x = tf.keras.activations.relu(x).numpy()
+
+        return x
+
+    def decode_stream(self, x):
+        _, dec_weights = self.extract_weights()
+
+        for i in range(len(dec_weights) // 2):
+            x = np.matmul(x, dec_weights[2 * i]) + dec_weights[2 * i + 1]
+
+            if i == len(dec_weights) // 2 - 1:
+                x = tf.keras.activations.softmax(x).numpy()
+            else:
+                x = tf.keras.activations.relu(x).numpy()
+
+        return x
+
+    def cost(self, y_true, y_pred):
+        symbol_cost = losses.CategoricalCrossentropy()(y_true, y_pred)
+
+        y_bits_true = SymbolsToBits(self.cardinality)(y_true)
+        y_bits_pred = SymbolsToBits(self.cardinality)(y_pred)
+
+        bit_cost = losses.BinaryCrossentropy()(y_bits_true, y_bits_pred)
+
+        a = 1
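+        # a weights the bit-level term relative to the symbol-level term; a = 1 weights them equally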
+
+        return symbol_cost + a * bit_cost
 
     def generate_random_inputs(self, num_of_blocks, return_vals=False):
         """
@@ -201,15 +193,17 @@ class EndToEndAutoencoder(tf.keras.Model):
         consecutively to model ISI. The central message in a block is returned as the label for training.
         :param return_vals: If true, the raw decimal values of the input sequence will be returned
         """
-        rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
         cat = [np.arange(self.cardinality)]
         enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
+        mid_idx = int((self.messages_per_block - 1) / 2)
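+        # index of the central message in a block; N is forced odd in __init__, so this is exact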
+
+        rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
+
         out = enc.fit_transform(rand_int)
-        out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
-        mid_idx = int((self.messages_per_block-1)/2)
+        out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
 
         if return_vals:
             out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
@@ -217,7 +211,7 @@ class EndToEndAutoencoder(tf.keras.Model):
 
         return out_arr, out_arr[:, mid_idx, :]
 
-    def train(self, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
+    def train(self, num_of_blocks=1e6, epochs=50, batch_size=None, train_size=0.8, lr=1e-3):
         """
         Method to train the autoencoder. Further configuration to the loss function, optimizer etc. can be made in here.
@@ -226,37 +220,110 @@ class EndToEndAutoencoder(tf.keras.Model):
         :param train_size: Float less than 1 representing the proportion of the dataset to use for training
         :param lr: The learning rate of the optimizer. Defines how quickly the algorithm converges
         """
-        X_train, y_train = self.generate_random_inputs(int(num_of_blocks*train_size))
-        X_test, y_test = self.generate_random_inputs(int(num_of_blocks*(1-train_size)))
+        X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
+        X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
 
         opt = tf.keras.optimizers.Adam(learning_rate=lr)
 
+        if self.custom_loss_fn:
+            loss_fn = self.cost
+        else:
+            loss_fn = losses.CategoricalCrossentropy()
+
+        callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
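+        # EarlyStopping halts the fit once the training loss has not improved for 3 consecutive epochs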
+
         self.compile(optimizer=opt,
-                     loss=losses.BinaryCrossentropy(),
+                     loss=loss_fn,
                      metrics=['accuracy'],
                      loss_weights=None,
                      weighted_metrics=None,
                      run_eagerly=False
                      )
 
-        self.fit(x=X_train,
-                 y=y_train,
-                 batch_size=batch_size,
-                 epochs=1,
-                 shuffle=True,
-                 validation_data=(X_test, y_test)
-                 )
+        history = self.fit(x=X_train,
+                           y=y_train,
+                           batch_size=batch_size,
+                           epochs=epochs,
+                           callbacks=[callback],
+                           shuffle=True,
+                           validation_data=(X_test, y_test)
+                           )
+
+        if len(history.history['loss']) == epochs:
+            print("The model trained for the maximum number of epochs and may not have converged to a good solution. "
+                  "Setting a higher epoch number and retraining is recommended.")
+
+    def test(self, num_of_blocks=1e4, length_plot=False, plt_show=True):
+        X_test, y_test = self.generate_random_inputs(int(num_of_blocks))
+
+        y_out = self.call(X_test)
+
+        y_pred = tf.argmax(y_out, axis=1)
+        y_true = tf.argmax(y_test, axis=1)
+
+        self.symbol_error_rate = 1 - accuracy_score(y_true, y_pred)
+
+        bits_pred = SymbolsToBits(self.cardinality)(tf.one_hot(y_pred, self.cardinality)).numpy().flatten()
+        bits_true = SymbolsToBits(self.cardinality)(y_test).numpy().flatten()
+
+        self.bit_error_rate = 1 - accuracy_score(bits_true, bits_pred)
+
+        if length_plot:
+
+            lengths = np.linspace(0, 70, 50)
+
+            ber_l = []
+
+            for l in lengths:
+                tx_channel = OpticalChannel(fs=self.channel.layers[1].fs,
+                                            num_of_samples=self.channel.layers[1].num_of_samples,
+                                            dispersion_factor=self.channel.layers[1].dispersion_factor,
+                                            fiber_length=l,
+                                            lpf_cutoff=self.channel.layers[1].lpf_cutoff,
+                                            rx_stddev=self.channel.layers[1].rx_stddev,
+                                            sig_avg=self.channel.layers[1].sig_avg,
+                                            enob=self.channel.layers[1].enob)
+
+                test_channel = tf.keras.Sequential([
+                    layers.Flatten(),
+                    tx_channel,
+                    ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
+                ], name="test_channel_variable_length")
+
+                X_test_l, y_test_l = self.generate_random_inputs(int(num_of_blocks))
+
+                y_out_l = self.decoder(test_channel(self.encoder(X_test_l)))
+
+                y_pred_l = tf.argmax(y_out_l, axis=1)
+                # y_true_l = tf.argmax(y_test_l, axis=1)
+
+                bits_pred_l = SymbolsToBits(self.cardinality)(tf.one_hot(y_pred_l, self.cardinality)).numpy().flatten()
+                bits_true_l = SymbolsToBits(self.cardinality)(y_test_l).numpy().flatten()
+
+                bit_error_rate_l = 1 - accuracy_score(bits_true_l, bits_pred_l)
+                ber_l.append(bit_error_rate_l)
+
+            plt.plot(lengths, ber_l)
+            plt.yscale('log')
+            if plt_show:
+                plt.show()
+
+        print("SYMBOL ERROR RATE: {}".format(self.symbol_error_rate))
+        print("BIT ERROR RATE: {}".format(self.bit_error_rate))
+
+        pass
 
     def view_encoder(self):
         '''
-        A method that views the learnt encoder for each distint message. This is displayed as a plot with asubplot for
-        each image.
+        A method that views the learnt encoder for each distinct message. This is displayed as a plot with a subplot
+        for each message/symbol.
         '''
+
+        mid_idx = int((self.messages_per_block - 1) / 2)
+
         # Generate inputs for encoder
         messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
-        mid_idx = int((self.messages_per_block-1)/2)
-
         idx = 0
         for msg in messages:
             msg[mid_idx, idx] = 1
@@ -268,23 +335,23 @@ class EndToEndAutoencoder(tf.keras.Model):
         # Compute subplot grid layout
         i = 0
-        while 2**i < self.cardinality**0.5:
+        while 2 ** i < self.cardinality ** 0.5:
             i += 1
-        num_x = int(2**i)
+        num_x = int(2 ** i)
         num_y = int(self.cardinality / num_x)
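+        # e.g. M = 32: sqrt(32) ~ 5.7, so the loop stops at i = 3, giving an 8 x 4 grid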
 
         # Plot all symbols
-        fig, axs = plt.subplots(num_y, num_x, figsize=(2.5*num_x, 2*num_y))
+        fig, axs = plt.subplots(num_y, num_x, figsize=(2.5 * num_x, 2 * num_y))
         t = np.arange(self.samples_per_symbol)
         if isinstance(self.channel.layers[1], OpticalChannel):
-            t = t/self.channel.layers[1].fs
+            t = t / self.channel.layers[1].fs
 
         sym_idx = 0
         for y in range(num_y):
             for x in range(num_x):
-                axs[y, x].plot(t, enc_messages[sym_idx], 'x')
+                axs[y, x].plot(t, enc_messages[sym_idx].numpy().flatten(), 'x')
                 axs[y, x].set_title('Symbol {}'.format(str(sym_idx)))
                 sym_idx += 1
@@ -308,33 +375,40 @@ class EndToEndAutoencoder(tf.keras.Model):
         # Encode and flatten the messages
         enc = self.encoder(inp)
         flat_enc = layers.Flatten()(enc)
+        chan_out = self.channel.layers[1](flat_enc)
 
         # Instantiate LPF layer
         lpf = DigitizationLayer(fs=self.channel.layers[1].fs,
-                                num_of_samples=self.messages_per_block*self.samples_per_symbol,
-                                q_stddev=0)
+                                num_of_samples=self.messages_per_block * self.samples_per_symbol,
+                                sig_avg=0)
 
         # Apply LPF
         lpf_out = lpf(flat_enc)
+        # Plot the magnitude spectrum of the filtered block
+        a = np.fft.fft(lpf_out.numpy()).flatten()
+        f = np.fft.fftfreq(a.shape[-1]).flatten()
+
+        plt.plot(f, np.abs(a))
+        plt.show()
+
         # Time axis
-        t = np.arange(self.messages_per_block*self.samples_per_symbol)
+        t = np.arange(self.messages_per_block * self.samples_per_symbol)
         if isinstance(self.channel.layers[1], OpticalChannel):
             t = t / self.channel.layers[1].fs
 
         # Plot the concatenated symbols before and after LPF
-        plt.figure(figsize=(2*self.messages_per_block, 6))
+        plt.figure(figsize=(2 * self.messages_per_block, 6))
         for i in range(1, self.messages_per_block):
-            plt.axvline(x=t[i*self.samples_per_symbol], color='black')
+            plt.axvline(x=t[i * self.samples_per_symbol], color='black')
 
         plt.plot(t, flat_enc.numpy().T, 'x')
         plt.plot(t, lpf_out.numpy().T)
+        plt.plot(t, chan_out.numpy().flatten())
         plt.ylim((0, 1))
         plt.xlim((t.min(), t.max()))
         plt.title(str(val[0, :, 0]))
         plt.show()
-        pass
 
     def call(self, inputs, training=None, mask=None):
         tx = self.encoder(inputs)
@@ -343,27 +417,152 @@ class EndToEndAutoencoder(tf.keras.Model):
         return outputs
 
 
-if __name__ == '__main__':
-
-    SAMPLING_FREQUENCY = 336e9
-    CARDINALITY = 32
-    SAMPLES_PER_SYMBOL = 24
-    MESSAGES_PER_BLOCK = 9
-    DISPERSION_FACTOR = -21.7 * 1e-24
-    FIBER_LENGTH = 50
+def load_model(model_name=None):
+    if model_name is None:
+        models = sorted(os.listdir("exports"))
+        if not models:
+            raise Exception("Unable to find a trained model. Please first train and save a model.")
+        model_name = models[-1]
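+        # export directories are named by UTC timestamp, so the lexicographically
+        # largest entry is the most recent export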
+
+    param_file_path = os.path.join("exports", model_name, "params.json")
+
+    if not os.path.isfile(param_file_path):
+        raise Exception("Invalid File Name/Directory")
+    else:
+        with open(param_file_path, 'r') as param_file:
+            params = json.load(param_file)
+
+    optical_channel = OpticalChannel(fs=params["fs"],
+                                     num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
+                                     dispersion_factor=params["dispersion_factor"],
+                                     fiber_length=params["fiber_length"],
+                                     fiber_length_stddev=params["fiber_length_stddev"],
+                                     lpf_cutoff=params["lpf_cutoff"],
+                                     rx_stddev=params["rx_stddev"],
+                                     sig_avg=params["sig_avg"],
+                                     enob=params["enob"])
+
+    ae_model = EndToEndAutoencoder(cardinality=params["cardinality"],
+                                   samples_per_symbol=params["samples_per_symbol"],
+                                   messages_per_block=params["messages_per_block"],
+                                   channel=optical_channel,
+                                   custom_loss_fn=params["custom_loss_fn"])
+
+    ae_model.encoder = tf.keras.models.load_model(os.path.join("exports", model_name, "encoder"))
+    ae_model.decoder = tf.keras.models.load_model(os.path.join("exports", model_name, "decoder"))
+
+    return ae_model, params
+
+
+if __name__ == 'asd':
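+    # NOTE: this guard is never true, so the fiber-length sweep below is effectively disabled;
+    # presumably it is kept for manual runs - change the string to '__main__' to execute it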
+
+    params = {"fs": 336e9,
+              "cardinality": 32,
+              "samples_per_symbol": 32,
+              "messages_per_block": 9,
+              "dispersion_factor": (-21.7 * 1e-24),
+              "fiber_length": 50,
+              "fiber_length_stddev": 1,
+              "lpf_cutoff": 32e9,
+              "rx_stddev": 0.01,
+              "sig_avg": 0.5,
+              "enob": 8,
+              "custom_loss_fn": True
+              }
+
+    lengths = np.linspace(40, 100, 7)
+    ber = []
+    for len_ in lengths:
+        optical_channel = OpticalChannel(fs=params["fs"],
+                                         num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
+                                         dispersion_factor=params["dispersion_factor"],
+                                         fiber_length=len_,
+                                         fiber_length_stddev=params["fiber_length_stddev"],
+                                         lpf_cutoff=params["lpf_cutoff"],
+                                         rx_stddev=0,
+                                         sig_avg=0,
+                                         enob=params["enob"])
+
+        ae_model = EndToEndAutoencoder(cardinality=params["cardinality"],
+                                       samples_per_symbol=params["samples_per_symbol"],
+                                       messages_per_block=params["messages_per_block"],
+                                       channel=optical_channel,
+                                       custom_loss_fn=params["custom_loss_fn"])
+        ae_model.train(num_of_blocks=1e5)
+        ae_model.test()
+        ber.append(ae_model.bit_error_rate)
+
+    plt.plot(lengths, ber)
+    plt.title("Bit Error Rate at different trained fiber lengths")
+    plt.yscale('log')
+    plt.xlabel("Fiber Length / km")
+    plt.ylabel("Bit Error Rate")
+    plt.show()
+    pass
 
-    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
-                                     num_of_samples=MESSAGES_PER_BLOCK*SAMPLES_PER_SYMBOL,
-                                     dispersion_factor=DISPERSION_FACTOR,
-                                     fiber_length=FIBER_LENGTH)
+
+if __name__ == '__main__':
 
-    ae_model = EndToEndAutoencoder(cardinality=CARDINALITY,
-                                   samples_per_symbol=SAMPLES_PER_SYMBOL,
-                                   messages_per_block=MESSAGES_PER_BLOCK,
-                                   channel=optical_channel)
+    params = {"fs": 336e9,
+              "cardinality": 32,
+              "samples_per_symbol": 32,
+              "messages_per_block": 9,
+              "dispersion_factor": (-21.7 * 1e-24),
+              "fiber_length": 50,
+              "fiber_length_stddev": 1,
+              "lpf_cutoff": 32e9,
+              "rx_stddev": 0.01,
+              "sig_avg": 0.5,
+              "enob": 8,
+              "custom_loss_fn": True
+              }
+
+    force_training = False
+
+    model_save_name = "20210317-124015"
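+    # hard-coded export timestamp - set this to the directory name of the export you want to load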
+    param_file_path = os.path.join("exports", model_save_name, "params.json")
+
+    if os.path.isfile(param_file_path) and not force_training:
+        print("Importing model {}".format(model_save_name))
+        with open(param_file_path, 'r') as file:
+            params = json.load(file)
+
+    optical_channel = OpticalChannel(fs=params["fs"],
+                                     num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
+                                     dispersion_factor=params["dispersion_factor"],
+                                     fiber_length=params["fiber_length"],
+                                     fiber_length_stddev=params["fiber_length_stddev"],
+                                     lpf_cutoff=params["lpf_cutoff"],
+                                     rx_stddev=params["rx_stddev"],
+                                     sig_avg=params["sig_avg"],
+                                     enob=params["enob"])
+
+    ae_model = EndToEndAutoencoder(cardinality=params["cardinality"],
+                                   samples_per_symbol=params["samples_per_symbol"],
+                                   messages_per_block=params["messages_per_block"],
+                                   channel=optical_channel,
+                                   custom_loss_fn=params["custom_loss_fn"])
+
+    if os.path.isfile(param_file_path) and not force_training:
+        ae_model.encoder = tf.keras.models.load_model(os.path.join("exports", model_save_name, "encoder"))
+        ae_model.decoder = tf.keras.models.load_model(os.path.join("exports", model_save_name, "decoder"))
+    else:
+        ae_model.train(num_of_blocks=1e4)
+        ae_model.save_end_to_end()
 
-    ae_model.train(num_of_blocks=1e6, batch_size=100)
     ae_model.view_encoder()
-    ae_model.view_sample_block()
+    ae_model.test()
+
+    # cat = [np.arange(32)]
+    # enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
+    #
+    # inp = np.asarray([9, 28, 15, 18, 23, 0, 29, 30, 2]).reshape(-1, 1)
+    # inp_oh = enc.fit_transform(inp)
+    #
+    # out = ae_model(inp_oh.reshape(1, 9, 32))
+    #
+    # a = out.numpy()
+    #
+    # plt.plot(a)
+    # plt.show()
 
     pass