ソースを参照

Merge branch 'matched_filter'

# Conflicts:
#	main.py
Min 5 年 前
コミット
7d5831b735
3 ファイル変更124 行追加25 行削除
  1. 5 0
      alphabets/4pam.a
  2. 58 9
      main.py
  3. 61 16
      models/optical_channel.py

+ 5 - 0
alphabets/4pam.a

@@ -0,0 +1,5 @@
+R
+00,0,0
+01,0.25,0
+10,0.5,0
+11,1,0

+ 58 - 9
main.py

@@ -7,6 +7,7 @@ from models.basic import AWGNChannel, BPSKDemod, BPSKMod, BypassChannel, Alphabe
 import misc
 import math
 from models.autoencoder import Autoencoder, view_encoder
+from models.optical_channel import OpticalChannel
 
 
 def show_constellation(mod, chan, demod, samples=1000):
@@ -56,6 +57,16 @@ def get_SNR(mod, demod, samples=1000, start=-8, stop=5, steps=30):
     SNR = (ber_x * -1) + av_sig_pow
     return SNR, ber_y
 
+def get_Optical_ber(mod, demod, samples=1000, start=-8, stop=5, steps=30, length=100, pulse_shape='rect'):
+    ber_x = np.linspace(start, stop, steps)
+    ber_y = []
+
+    for noise in ber_x:
+        tx_channel = OpticalChannel(noise_level=noise, dispersion=-21.7, symbol_rate=10e9, sample_rate=400e9,
+                                    length=length, pulse_shape=pulse_shape, sqrt_out=True)
+        ber_y.append(get_ber(mod, tx_channel, demod, samples=samples))
+    return ber_x, ber_y
+
 if __name__ == '__main__':
     # show_constellation(BPSKMod(10e6), AWGNChannel(-1), BPSKDemod(10e6, 10e3))
 
@@ -121,20 +132,58 @@ if __name__ == '__main__':
     # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
     #          label='AE 4bit -8dB')
 
-    for scheme in ['64qam', '32qam', '16qam', 'qpsk', '8psk']:
-        plt.plot(*get_SNR(
-            AlphabetMod(scheme, 10e6),
-            AlphabetDemod(scheme, 10e6),
-            samples=100e3,
+    # for scheme in ['64qam', '32qam', '16qam', 'qpsk', '8psk']:
+    #     plt.plot(*get_SNR(
+    #         AlphabetMod(scheme, 10e6),
+    #         AlphabetDemod(scheme, 10e6),
+    #         samples=100e3,
+    #         steps=40,
+    #         start=-15
+    #     ), '-', label=scheme.upper())
+    # plt.yscale('log')
+    # plt.grid()
+    # plt.xlabel('SNR dB')
+    # plt.ylabel('BER')
+    # plt.legend()
+    # plt.show()
+
+    for l in np.logspace(start=0, stop=3, num=5):
+        plt.plot(*get_Optical_ber(
+            AlphabetMod('4pam', 10e6),
+            AlphabetDemod('4pam', 10e6),
+            samples=1000,
+            steps=40,
+            start=-15,
+            length=l,
+            pulse_shape='rcos'
+        ), '-', label=(str(int(l))+'km'))
+
+    plt.yscale('log')
+    plt.gca().invert_xaxis()
+    plt.grid()
+    plt.xlabel('Noise dB')
+    plt.ylabel('BER')
+    plt.title("BER against Fiber length")
+    plt.legend()
+    plt.show()
+
+    for ps in ['rect', 'rcos', 'rrcos']:
+        plt.plot(*get_Optical_ber(
+            AlphabetMod('4pam', 10e6),
+            AlphabetDemod('4pam', 10e6),
+            samples=1000,
             steps=40,
-            start=-15
-        ), '-', label=scheme.upper())
+            start=-15,
+            length=10,
+            pulse_shape=ps
+        ), '-', label=ps)
 
     plt.yscale('log')
-    # plt.gca().invert_xaxis()
+    plt.gca().invert_xaxis()
     plt.grid()
-    plt.xlabel('SNR dB')
+    plt.xlabel('Noise dB')
     plt.ylabel('BER')
+    plt.title("BER for different pulse shapes")
     plt.legend()
     plt.show()
 

+ 61 - 16
models/optical_channel.py

@@ -3,17 +3,20 @@ import matplotlib.pyplot as plt
 import defs
 import numpy as np
 import math
-from scipy.fft import fft, ifft
-
+from numpy.fft import fft, fftfreq, ifft
+from commpy.filters import rrcosfilter, rcosfilter, rectfilter
 
 class OpticalChannel(defs.Channel):
-    def __init__(self, noise_level, dispersion, symbol_rate, sample_rate, length, show_graphs=False, **kwargs):
+    def __init__(self, noise_level, dispersion, symbol_rate, sample_rate, length, pulse_shape='rect',
+                 sqrt_out=False, show_graphs=False, **kwargs):
         """
         :param noise_level: Noise level in dB
         :param dispersion: dispersion coefficient is ps^2/km
         :param symbol_rate: Symbol rate of modulated signal in Hz
         :param sample_rate: Sample rate of time-domain model (time steps in simulation) in Hz
         :param length: fibre length in km
+        :param pulse_shape: pulse shape -> ['rect', 'rcos', 'rrcos']
+        :param sqrt_out: Take the root of the out to compensate for photodiode detection
         :param show_graphs: if graphs should be displayed or not
 
         Optical Channel class constructor
@@ -21,29 +24,45 @@ class OpticalChannel(defs.Channel):
         super().__init__(**kwargs)
         self.noise = 10 ** (noise_level / 10)
 
-        self.dispersion = dispersion # * 1e-24  # Converting from ps^2/km to s^2/km
+        self.dispersion = dispersion * 1e-24  # Converting from ps^2/km to s^2/km
         self.symbol_rate = symbol_rate
         self.symbol_period = 1 / self.symbol_rate
         self.sample_rate = sample_rate
         self.sample_period = 1 / self.sample_rate
         self.length = length
+        self.pulse_shape = pulse_shape.strip().lower()
+        self.sqrt_out = sqrt_out
         self.show_graphs = show_graphs
 
     def __get_time_domain(self, symbol_vals):
         samples_per_symbol = int(self.sample_rate / self.symbol_rate)
         samples = int(symbol_vals.shape[0] * samples_per_symbol)
 
-        symbol_vals_a = np.repeat(symbol_vals, repeats=samples_per_symbol, axis=0)
-        t = np.linspace(start=0, stop=samples * self.sample_period, num=samples)
-        val_t = symbol_vals_a[:, 0] * np.cos(2 * math.pi * symbol_vals_a[:, 2] * t + symbol_vals_a[:, 1])
+        symbol_impulse = np.zeros(samples)
+
+        # TODO: Implement Frequency/Phase Modulation
+
+        for i in range(symbol_vals.shape[0]):
+            symbol_impulse[i*samples_per_symbol] = symbol_vals[i, 0]
+
+        if self.pulse_shape == 'rrcos':
+            self.filter_samples = 5 * samples_per_symbol
+            self.t_filter, self.h_filter = rrcosfilter(self.filter_samples, 0.8, self.symbol_period, self.sample_rate)
+        elif self.pulse_shape == 'rcos':
+            self.filter_samples = 5 * samples_per_symbol
+            self.t_filter, self.h_filter = rcosfilter(self.filter_samples, 0.8, self.symbol_period, self.sample_rate)
+        else:
+            self.filter_samples = samples_per_symbol
+            self.t_filter, self.h_filter = rectfilter(self.filter_samples, self.symbol_period, self.sample_rate)
+
+        val_t = np.convolve(symbol_impulse, self.h_filter)
+        t = np.linspace(start=0, stop=val_t.shape[0] * self.sample_period, num=val_t.shape[0])
 
         return t, val_t
 
     def __time_to_frequency(self, values):
         val_f = fft(values)
-        f = np.linspace(0.0, 1 / (2 * self.sample_period), (values.size // 2))
-        f_neg = -1 * np.flip(f)
-        f = np.concatenate((f, f_neg), axis=0)
+        f = fftfreq(values.shape[-1])*self.sample_rate
         return f, val_f
 
     def __frequency_to_time(self, values):
@@ -106,22 +125,48 @@ class OpticalChannel(defs.Channel):
         # Photodiode Detection
         t, val_t = self.__photodiode_detection(val_t)
 
+        # Symbol Decisions
+        idx = np.arange(self.filter_samples/2, t.shape[0] - (self.filter_samples/2),
+                        self.symbol_period/self.sample_period, dtype='int16')
+        t_descision = self.sample_period * idx
+
         if self.show_graphs:
             plt.plot(t, val_t)
             plt.title('time domain (post-detection)')
             plt.show()
 
-        return t, val_t
+            plt.plot(t, val_t)
+            for xc in t_descision:
+                plt.axvline(x=xc, color='r')
+            plt.title('time domain (post-detection with decision times)')
+            plt.show()
+
+        # TODO: Implement Frequency/Phase Modulation
+
+        out = np.zeros(values.shape)
+
+        out[:, 0] = val_t[idx]
+        out[:, 1] = values[:, 1]
+        out[:, 2] = values[:, 2]
+
+        if self.sqrt_out:
+            out[:, 0] = np.sqrt(out[:, 0])
+
+        return out
 
 
 if __name__ == '__main__':
     # Simple OOK modulation
-    num_of_symbols = 10
+    num_of_symbols = 100
     symbol_vals = np.zeros((num_of_symbols, 3))
 
     symbol_vals[:, 0] = np.random.randint(2, size=symbol_vals.shape[0])
-    symbol_vals[:, 2] = 10e6
+    symbol_vals[:, 2] = 40e9
+
+    channel = OpticalChannel(noise_level=-10, dispersion=-21.7, symbol_rate=10e9,
+                             sample_rate=400e9, length=100, pulse_shape='rcos', show_graphs=True)
+    v = channel.forward(symbol_vals)
 
-    channel = OpticalChannel(noise_level=-20, dispersion=-21.7, symbol_rate=100e3,
-                             sample_rate=500e6, length=100, show_graphs=True)
-    time, v = channel.forward(symbol_vals)
+    rx = (v > 0.5).astype(int)
+    tru = np.sum(rx == symbol_vals[:, 0].astype(int))
+    print("Accuracy: {}".format(tru/num_of_symbols))