ソースを参照

ber plot for optical channel

Tharmetharan Balendran 5 年 前
コミット
6d04f26a97
3 ファイル変更72 行追加9 行削除
  1. 5 0
      alphabets/4pam.a
  2. 50 7
      main.py
  3. 17 2
      models/optical_channel.py

+ 5 - 0
alphabets/4pam.a

@@ -0,0 +1,5 @@
+R
+00,0,0
+01,0.25,0
+10,0.5,0
+11,1,0

+ 50 - 7
main.py

@@ -6,6 +6,7 @@ from models import basic
 from models.basic import AWGNChannel, BPSKDemod, BPSKMod, BypassChannel, AlphabetMod, AlphabetDemod
 from models.basic import AWGNChannel, BPSKDemod, BPSKMod, BypassChannel, AlphabetMod, AlphabetDemod
 import misc
 import misc
 from models.autoencoder import Autoencoder, view_encoder
 from models.autoencoder import Autoencoder, view_encoder
+from models.optical_channel import OpticalChannel
 
 
 
 
 def show_constellation(mod, chan, demod, samples=1000):
 def show_constellation(mod, chan, demod, samples=1000):
@@ -45,6 +46,16 @@ def get_AWGN_ber(mod, demod, samples=1000, start=-8, stop=5, steps=30):
     return ber_x, ber_y
     return ber_x, ber_y
 
 
 
 
+def get_Optical_ber(mod, demod, samples=1000, start=-8, stop=5, steps=30, length=100, pulse_shape='rect'):
+    ber_x = np.linspace(start, stop, steps)
+    ber_y = []
+
+    for noise in ber_x:
+        tx_channel = OpticalChannel(noise_level=noise, dispersion=-21.7, symbol_rate=10e9, sample_rate=400e9,
+                                    length=length, pulse_shape=pulse_shape, sqrt_out=True)
+        ber_y.append(get_ber(mod, tx_channel, demod, samples=samples))
+    return ber_x, ber_y
+
 if __name__ == '__main__':
 if __name__ == '__main__':
     # show_constellation(BPSKMod(10e6), AWGNChannel(-1), BPSKDemod(10e6, 10e3))
     # show_constellation(BPSKMod(10e6), AWGNChannel(-1), BPSKDemod(10e6, 10e3))
 
 
@@ -110,20 +121,52 @@ if __name__ == '__main__':
     # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
     # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
     #          label='AE 4bit -8dB')
     #          label='AE 4bit -8dB')
 
 
-    for scheme in ['64qam', '32qam', '16qam', 'qpsk', '8psk']:
-        plt.plot(*get_AWGN_ber(
-            AlphabetMod(scheme, 10e6),
-            AlphabetDemod(scheme, 10e6),
-            samples=20e3,
+    # for scheme in ['64qam', '32qam', '16qam', 'qpsk', '8psk']:
+    #     plt.plot(*get_AWGN_ber(
+    #         AlphabetMod(scheme, 10e6),
+    #         AlphabetDemod(scheme, 10e6),
+    #         samples=20e3,
+    #         steps=40,
+    #         start=-15
+    #     ), '-', label=scheme.upper())
+
+    for l in np.logspace(start=0, stop=3, num=5):
+        plt.plot(*get_Optical_ber(
+            AlphabetMod('4pam', 10e6),
+            AlphabetDemod('4pam', 10e6),
+            samples=1000,
+            steps=40,
+            start=-15,
+            length=l,
+            pulse_shape='rcos'
+        ), '-', label=(str(int(l))+'km'))
+
+    plt.yscale('log')
+    plt.gca().invert_xaxis()
+    plt.grid()
+    plt.xlabel('Noise dB')
+    plt.ylabel('BER')
+    plt.title("BER against Fiber length")
+    plt.legend()
+    plt.show()
+
+    for ps in ['rect', 'rcos', 'rrcos']:
+        plt.plot(*get_Optical_ber(
+            AlphabetMod('4pam', 10e6),
+            AlphabetDemod('4pam', 10e6),
+            samples=1000,
             steps=40,
             steps=40,
-            start=-15
-        ), '-', label=scheme.upper())
+            start=-15,
+            length=10,
+            pulse_shape=ps
+        ), '-', label=ps)
 
 
     plt.yscale('log')
     plt.yscale('log')
     plt.gca().invert_xaxis()
     plt.gca().invert_xaxis()
     plt.grid()
     plt.grid()
     plt.xlabel('Noise dB')
     plt.xlabel('Noise dB')
     plt.ylabel('BER')
     plt.ylabel('BER')
+    plt.title("BER for different pulse shapes")
     plt.legend()
     plt.legend()
     plt.show()
     plt.show()
 
 

+ 17 - 2
models/optical_channel.py

@@ -8,7 +8,7 @@ from commpy.filters import rrcosfilter, rcosfilter, rectfilter
 
 
 class OpticalChannel(defs.Channel):
 class OpticalChannel(defs.Channel):
     def __init__(self, noise_level, dispersion, symbol_rate, sample_rate, length, pulse_shape='rect',
     def __init__(self, noise_level, dispersion, symbol_rate, sample_rate, length, pulse_shape='rect',
-                 show_graphs=False, **kwargs):
+                 sqrt_out=False, show_graphs=False, **kwargs):
         """
         """
         :param noise_level: Noise level in dB
         :param noise_level: Noise level in dB
         :param dispersion: dispersion coefficient is ps^2/km
         :param dispersion: dispersion coefficient is ps^2/km
@@ -16,6 +16,7 @@ class OpticalChannel(defs.Channel):
         :param sample_rate: Sample rate of time-domain model (time steps in simulation) in Hz
         :param sample_rate: Sample rate of time-domain model (time steps in simulation) in Hz
         :param length: fibre length in km
         :param length: fibre length in km
         :param pulse_shape: pulse shape -> ['rect', 'rcos', 'rrcos']
         :param pulse_shape: pulse shape -> ['rect', 'rcos', 'rrcos']
+        :param sqrt_out: Take the root of the out to compensate for photodiode detection
         :param show_graphs: if graphs should be displayed or not
         :param show_graphs: if graphs should be displayed or not
 
 
         Optical Channel class constructor
         Optical Channel class constructor
@@ -30,6 +31,7 @@ class OpticalChannel(defs.Channel):
         self.sample_period = 1 / self.sample_rate
         self.sample_period = 1 / self.sample_rate
         self.length = length
         self.length = length
         self.pulse_shape = pulse_shape.strip().lower()
         self.pulse_shape = pulse_shape.strip().lower()
+        self.sqrt_out = sqrt_out
         self.show_graphs = show_graphs
         self.show_graphs = show_graphs
 
 
     def __get_time_domain(self, symbol_vals):
     def __get_time_domain(self, symbol_vals):
@@ -38,6 +40,8 @@ class OpticalChannel(defs.Channel):
 
 
         symbol_impulse = np.zeros(samples)
         symbol_impulse = np.zeros(samples)
 
 
+        # TODO: Implement Frequency/Phase Modulation
+
         for i in range(symbol_vals.shape[0]):
         for i in range(symbol_vals.shape[0]):
             symbol_impulse[i*samples_per_symbol] = symbol_vals[i, 0]
             symbol_impulse[i*samples_per_symbol] = symbol_vals[i, 0]
 
 
@@ -137,7 +141,18 @@ class OpticalChannel(defs.Channel):
             plt.title('time domain (post-detection with decision times)')
             plt.title('time domain (post-detection with decision times)')
             plt.show()
             plt.show()
 
 
-        return val_t[idx]
+        # TODO: Implement Frequency/Phase Modulation
+
+        out = np.zeros(values.shape)
+
+        out[:, 0] = val_t[idx]
+        out[:, 1] = values[:, 1]
+        out[:, 2] = values[:, 2]
+
+        if self.sqrt_out:
+            out[:, 0] = np.sqrt(out[:, 0])
+
+        return out
 
 
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':