Tharmetharan Balendran 4 anos atrás
pai
commit
c0e372170d
2 arquivos alterados com 42 adições e 3 exclusões
  1. 31 1
      models/end_to_end.py
  2. 11 2
      models/layers.py

+ 31 - 1
models/end_to_end.py

@@ -524,7 +524,6 @@ def run_tests(distance=50):
         mode='auto', baseline=None, restore_best_weights=True
         mode='auto', baseline=None, restore_best_weights=True
     )
     )
 
 
-
     # model_checkpoint_callback1 = tf.keras.callbacks.ModelCheckpoint(
     # model_checkpoint_callback1 = tf.keras.callbacks.ModelCheckpoint(
     #     filepath='/tmp/checkpoint/quantised',
     #     filepath='/tmp/checkpoint/quantised',
     #     save_weights_only=True,
     #     save_weights_only=True,
@@ -574,6 +573,37 @@ def run_tests(distance=50):
 
 
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
+    params = {
+        "fs": 336e9,
+        "cardinality": 64,
+        "samples_per_symbol": 48,
+        "messages_per_block": 9,
+        "dispersion_factor": (-21.7 * 1e-24),
+        "fiber_length": 20,
+        "fiber_length_stddev": 1,
+        "lpf_cutoff": 32e9,
+        "rx_stddev": 0.13,
+        "sig_avg": 0.5,
+        "enob": 6,
+        "custom_loss_fn": True
+    }
+
+    optical_channel = OpticalChannel(
+        fs=params["fs"],
+        num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
+        dispersion_factor=params["dispersion_factor"],
+        fiber_length=params["fiber_length"],
+        fiber_length_stddev=params["fiber_length_stddev"],
+        lpf_cutoff=params["lpf_cutoff"],
+        rx_stddev=params["rx_stddev"],
+        sig_avg=params["sig_avg"],
+        enob=params["enob"],
+    )
+
+    print(optical_channel.compute_snr())
+
+
+if __name__ == 'asd':
     data0 = run_tests(90)
     data0 = run_tests(90)
     # data1 = run_tests(70)
     # data1 = run_tests(70)
     # data2 = run_tests(80)
     # data2 = run_tests(80)

+ 11 - 2
models/layers.py

@@ -144,7 +144,8 @@ class OpticalChannel(layers.Layer):
                  lpf_cutoff=32e9,
                  lpf_cutoff=32e9,
                  rx_stddev=0.01,
                  rx_stddev=0.01,
                  sig_avg=0.5,
                  sig_avg=0.5,
-                 enob=10):
+                 enob=10,
+                 atten=0.2):
         """
         """
         A channel model that simulates chromatic dispersion, non-linear photodiode detection, finite bandwidth of
         A channel model that simulates chromatic dispersion, non-linear photodiode detection, finite bandwidth of
         ADC/DAC as well as additive white gaussian noise in optical communication channels.
         ADC/DAC as well as additive white gaussian noise in optical communication channels.
@@ -168,6 +169,7 @@ class OpticalChannel(layers.Layer):
         self.rx_stddev = rx_stddev
         self.rx_stddev = rx_stddev
         self.sig_avg = sig_avg
         self.sig_avg = sig_avg
         self.enob = enob
         self.enob = enob
+        self.total_atten = tf.cast(10**(-0.1*atten*fiber_length), dtype=tf.float32)
 
 
         self.noise_layer = layers.GaussianNoise(self.rx_stddev)
         self.noise_layer = layers.GaussianNoise(self.rx_stddev)
         self.fiber_length_noise = layers.GaussianNoise(self.fiber_length_stddev)
         self.fiber_length_noise = layers.GaussianNoise(self.fiber_length_stddev)
@@ -181,6 +183,12 @@ class OpticalChannel(layers.Layer):
         self.freq = tf.convert_to_tensor(np.fft.fftfreq(self.num_of_samples, d=1/fs), dtype=tf.complex64)
         self.freq = tf.convert_to_tensor(np.fft.fftfreq(self.num_of_samples, d=1/fs), dtype=tf.complex64)
         # self.multiplier = tf.math.exp(0.5j*self.dispersion_factor*self.fiber_length*tf.math.square(2*math.pi*self.freq))
         # self.multiplier = tf.math.exp(0.5j*self.dispersion_factor*self.fiber_length*tf.math.square(2*math.pi*self.freq))
 
 
+    def compute_snr(self):
+        average_optical_power = (2*self.sig_avg*self.total_atten) / 3
+        signal_var = self.rx_stddev**2
+
+        return 10*math.log(average_optical_power/signal_var)
+
     def call(self, inputs, **kwargs):
     def call(self, inputs, **kwargs):
         # DAC LPF and noise
         # DAC LPF and noise
         dac_out = self.digitization_layer(inputs)
         dac_out = self.digitization_layer(inputs)
@@ -191,7 +199,8 @@ class OpticalChannel(layers.Layer):
 
 
         len = tf.cast(self.fiber_length_noise.call(self.fiber_length), dtype=tf.complex64)
         len = tf.cast(self.fiber_length_noise.call(self.fiber_length), dtype=tf.complex64)
 
 
-        multiplier = tf.math.exp(0.5j*self.dispersion_factor*len*tf.math.square(2*math.pi*self.freq))
+        # Applying fiber dispersion and attenuation
+        multiplier = self.total_atten * tf.math.exp(0.5j*self.dispersion_factor*len*tf.math.square(2*math.pi*self.freq))
 
 
         disp_f = tf.math.multiply(val_f, multiplier)
         disp_f = tf.math.multiply(val_f, multiplier)
         disp_t = tf.signal.ifft(disp_f)
         disp_t = tf.signal.ifft(disp_f)