Pārlūkot izejas kodu

fixed merging end_to_end

Min 4 gadi atpakaļ
vecāks
revīzija
e6940a5e2c
1 mainītis faili ar 2 papildinājumiem un 113 dzēšanām
  1. 2 113
      models/end_to_end.py

+ 2 - 113
models/end_to_end.py

@@ -1,3 +1,4 @@
+import itertools
 import math
 
 import tensorflow as tf
@@ -7,119 +8,7 @@ from sklearn.metrics import accuracy_score
 from sklearn.preprocessing import OneHotEncoder
 from tensorflow.keras import layers, losses
 
-
-class ExtractCentralMessage(layers.Layer):
-    def __init__(self, messages_per_block, samples_per_symbol):
-        """
-        :param messages_per_block: Total number of messages in transmission block
-        :param samples_per_symbol: Number of samples per transmitted symbol
-        """
-        super(ExtractCentralMessage, self).__init__()
-
-        temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
-        i = np.identity(samples_per_symbol)
-        begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
-        end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
-        temp_w[begin:end, :] = i
-
-        self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)
-
-    def call(self, inputs, **kwargs):
-        return tf.matmul(inputs, self.w)
-
-
-class DigitizationLayer(layers.Layer):
-    def __init__(self,
-                 fs,
-                 num_of_samples,
-                 lpf_cutoff=32e9,
-                 q_stddev=0.1):
-        """
-        :param fs: Sampling frequency of the simulation in Hz
-        :param num_of_samples: Total number of samples in the input
-        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
-        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
-        """
-        super(DigitizationLayer, self).__init__()
-
-        self.noise_layer = layers.GaussianNoise(q_stddev)
-        freq = np.fft.fftfreq(num_of_samples, d=1/fs)
-        temp = np.ones(freq.shape)
-
-        for idx, val in np.ndenumerate(freq):
-            if np.abs(val) > lpf_cutoff:
-                temp[idx] = 0
-
-        self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)
-
-    def call(self, inputs, **kwargs):
-        complex_in = tf.cast(inputs, dtype=tf.complex64)
-        val_f = tf.signal.fft(complex_in)
-        filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
-        filtered_t = tf.signal.ifft(filtered_f)
-        real_t = tf.cast(filtered_t, dtype=tf.float32)
-        noisy = self.noise_layer.call(real_t, training=True)
-        return noisy
-
-
-class OpticalChannel(layers.Layer):
-    def __init__(self,
-                 fs,
-                 num_of_samples,
-                 dispersion_factor,
-                 fiber_length,
-                 lpf_cutoff=32e9,
-                 rx_stddev=0.01,
-                 q_stddev=0.01):
-        """
-        :param fs: Sampling frequency of the simulation in Hz
-        :param num_of_samples: Total number of samples in the input
-        :param dispersion_factor: Dispersion factor in s^2/km
-        :param fiber_length: Length of fiber to model in km
-        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
-        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
-        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
-        """
-        super(OpticalChannel, self).__init__()
-
-        self.noise_layer = layers.GaussianNoise(rx_stddev)
-        self.digitization_layer = DigitizationLayer(fs=fs,
-                                                    num_of_samples=num_of_samples,
-                                                    lpf_cutoff=lpf_cutoff,
-                                                    q_stddev=q_stddev)
-        self.flatten_layer = layers.Flatten()
-
-        self.fs = fs
-        self.freq = tf.convert_to_tensor(np.fft.fftfreq(num_of_samples, d=1/fs), dtype=tf.complex128)
-        self.multiplier = tf.math.exp(0.5j*dispersion_factor*fiber_length*tf.math.square(2*math.pi*self.freq))
-
-    def call(self, inputs, **kwargs):
-        # DAC LPF and noise
-        dac_out = self.digitization_layer(inputs)
-
-        # Chromatic Dispersion
-        complex_val = tf.cast(dac_out, dtype=tf.complex128)
-        val_f = tf.signal.fft(complex_val)
-        disp_f = tf.math.multiply(val_f, self.multiplier)
-        disp_t = tf.signal.ifft(disp_f)
-
-        # Squared-Law Detection
-        pd_out = tf.square(tf.abs(disp_t))
-
-        # Casting back to floatx
-        real_val = tf.cast(pd_out, dtype=tf.float32)
-
-        # Adding photo-diode receiver noise
-        rx_signal = self.noise_layer.call(real_val, training=True)
-
-        # ADC LPF and noise
-        adc_out = self.digitization_layer(rx_signal)
-
-        return adc_out
-from tensorflow.keras import layers, losses
-from tensorflow.keras import backend as K
-from custom_layers import ExtractCentralMessage, OpticalChannel, DigitizationLayer, BitsToSymbols, SymbolsToBits
-import itertools
+from layers import ExtractCentralMessage, BitsToSymbols, SymbolsToBits, OpticalChannel, DigitizationLayer
 
 
 class EndToEndAutoencoder(tf.keras.Model):