浏览代码

OFDM class with decision making

Tharmetharan Balendran 4 年之前
父节点
当前提交
c93e79f8ac
共有 2 个文件被更改,包括 291 次插入356 次删除
  1. 157 299
      models/OFDM.py
  2. 134 57
      tests/ofdm_test.py

+ 157 - 299
models/OFDM.py

@@ -1,303 +1,161 @@
 import numpy as np
-import matplotlib.pyplot as plt
-from scipy import interpolate
-
+import tensorflow as tf
+from tensorflow.keras import layers, losses
+from matplotlib import pyplot as plt
+from models.basic import AlphabetMod, AlphabetDemod, RFSignal
 from models.custom_layers import OpticalChannel
 
-SAMPLING_FREQUENCY = 336e9
-CARDINALITY = 4
-SAMPLES_PER_SYMBOL = 128
-MESSAGES_PER_BLOCK = 9
-DISPERSION_FACTOR = -21.7 * 1e-24
-FIBER_LENGTH = 50
-
-optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
-                                 num_of_samples=80,
-                                 dispersion_factor=DISPERSION_FACTOR,
-                                 fiber_length=FIBER_LENGTH)
-
-K = 64  # number of OFDM subcarriers
-CP = K // 4  # length of the cyclic prefix: 25% of the block
-P = 8  # number of pilot carriers per OFDM block
-pilotValue = 3 + 3j  # The known value each pilot transmits
-
-allCarriers = np.arange(K)  # indices of all subcarriers ([0, 1, ... K-1])
-
-pilotCarriers = allCarriers[::K // P]  # Pilots is every (K/P)th carrier.
-
-# For convenience of channel estimation, let's make the last carriers also be a pilot
-pilotCarriers = np.hstack([pilotCarriers, np.array([allCarriers[-1]])])
-P = P + 1
-
-# data carriers are all remaining carriers
-dataCarriers = np.delete(allCarriers, pilotCarriers)
-
-print("allCarriers:   %s" % allCarriers)
-print("pilotCarriers: %s" % pilotCarriers)
-print("dataCarriers:  %s" % dataCarriers)
-plt.plot(pilotCarriers, np.zeros_like(pilotCarriers), 'bo', label='pilot')
-plt.plot(dataCarriers, np.zeros_like(dataCarriers), 'ro', label='data')
-
-mu = 4  # bits per symbol (i.e. 16QAM)
-payloadBits_per_OFDM = len(dataCarriers) * mu  # number of payload bits per OFDM symbol
-
-mapping_table = {
-    (0, 0, 0, 0): -3 - 3j,
-    (0, 0, 0, 1): -3 - 1j,
-    (0, 0, 1, 0): -3 + 3j,
-    (0, 0, 1, 1): -3 + 1j,
-    (0, 1, 0, 0): -1 - 3j,
-    (0, 1, 0, 1): -1 - 1j,
-    (0, 1, 1, 0): -1 + 3j,
-    (0, 1, 1, 1): -1 + 1j,
-    (1, 0, 0, 0): 3 - 3j,
-    (1, 0, 0, 1): 3 - 1j,
-    (1, 0, 1, 0): 3 + 3j,
-    (1, 0, 1, 1): 3 + 1j,
-    (1, 1, 0, 0): 1 - 3j,
-    (1, 1, 0, 1): 1 - 1j,
-    (1, 1, 1, 0): 1 + 3j,
-    (1, 1, 1, 1): 1 + 1j
-}
-
-mapping_table_dec = {
-    (0): -3 - 3j,
-    (1): -3 - 1j,
-    (2): -3 + 3j,
-    (3): -3 + 1j,
-    (4): -1 - 3j,
-    (5): -1 - 1j,
-    (6): -1 + 3j,
-    (7): -1 + 1j,
-    (8): 3 - 3j,
-    (9): 3 - 1j,
-    (10): 3 + 3j,
-    (11): 3 + 1j,
-    (12): 1 - 3j,
-    (13): 1 - 1j,
-    (14): 1 + 3j,
-    (15): 1 + 1j
-}
-for b3 in [0, 1]:
-    for b2 in [0, 1]:
-        for b1 in [0, 1]:
-            for b0 in [0, 1]:
-                B = (b3, b2, b1, b0)
-                Q = mapping_table[B]
-                plt.plot(Q.real, Q.imag, 'bo')
-                plt.text(Q.real, Q.imag + 0.2, "".join(str(x) for x in B), ha='center')
-
-demapping_table = {v: k for k, v in mapping_table.items()}
-
-# Replace with our channel
-channelResponse = np.array([1, 0, 0.3 + 0.3j])  # the impulse response of the wireless channel
-H_exact = np.fft.fft(channelResponse, K)
-plt.plot(allCarriers, abs(H_exact))
-
-SNRdb = 25  # signal to noise-ratio in dB at the receiver
-
-# Here
-
-# water filling, gradient decent methods for optimising the symbol mapping, instead of 16 QAM
-
-bits = np.random.binomial(n=1, p=0.5, size=(payloadBits_per_OFDM,))
-print("Bits count: ", len(bits))
-print("First 20 bits: ", bits[:20])
-print("Mean of bits (should be around 0.5): ", np.mean(bits))
-
-
-def SP(bits):
-    return bits.reshape((len(dataCarriers), mu))
-
-
-bits_SP = SP(bits)
-print("First 5 bit groups")
-print(bits_SP[:5, :])
-
-
-def generate_random_inputs(num_of_blocks, return_vals=False):
-    """
-    A method that generates a list of one-hot encoded messages. This is utilized for generating the test/train data.
-
-    :param num_of_blocks: Number of blocks to generate. A block contains multiple messages to be transmitted in
-    consecutively to model ISI. The central message in a block is returned as the label for training.
-    :param return_vals: If true, the raw decimal values of the input sequence will be returned
-    """
-    rand_int = np.random.randint(16, size=(num_of_blocks, 1))
-
-    # cat = [np.arange(self.cardinality)]
-    # enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
-
-    # out = enc.fit_transform(rand_int)
-    # for symbol in rand_int:
-
-    # out_arr = np.reshape(rand_int, (num_of_blocks, self.messages_per_block))
-    # t_out_arr = np.repeat(out_arr, self.samples_per_symbol, axis=1)
-
-    # mid_idx = int((self.messages_per_block - 1) / 2)
-
-    # if return_vals:
-    # out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
-    #  return out_val, out_arr, out_arr[:, mid_idx, :]
-
-    return rand_int
-
-
-bits_SP1 = generate_random_inputs(num_of_blocks=1000).astype('uint8')
-
-
-# bits_SP11 = np.unpackbits(bits_SP1, axis=1)
-
-def Mapping(bits):
-    return np.array([mapping_table[tuple(b)] for b in bits])
-
-
-def Mapping_dec(bits):
-    return np.array([mapping_table_dec[tuple(b)] for b in bits])
-
-
-QAM = Mapping(bits_SP)
-QAM1 = Mapping_dec(bits_SP1)
-print("First 5 QAM symbols and bits:")
-print(bits_SP[:5, :])
-print(QAM[:5])
-
-
-def OFDM_symbol(QAM_payload):
-    symbol = np.zeros(K, dtype=complex)  # the overall K subcarriers
-    symbol[pilotCarriers] = pilotValue  # allocate the pilot subcarriers
-    symbol[dataCarriers] = QAM_payload  # allocate the pilot subcarriers
-    return symbol
-
-
-OFDM_data = OFDM_symbol(QAM)
-print("Number of OFDM carriers in frequency domain: ", len(OFDM_data))
-
-
-def IDFT(OFDM_data):
-    return np.fft.ifft(OFDM_data)
-
-
-OFDM_time = IDFT(OFDM_data)
-print("Number of OFDM samples in time-domain before CP: ", len(OFDM_time))
-
-
-def addCP(OFDM_time):
-    cp = OFDM_time[-CP:]  # take the last CP samples ...
-    return np.hstack([cp, OFDM_time])  # ... and add them to the beginning
-
-
-OFDM_withCP = addCP(OFDM_time)
-print("Number of OFDM samples in time domain with CP: ", len(OFDM_withCP))
-
-
-def channel(signal):
-    convolved = np.convolve(signal, channelResponse)
-    signal_power = np.mean(abs(convolved ** 2))
-    sigma2 = signal_power * 10 ** (-SNRdb / 10)  # calculate noise power based on signal power and SNR
-
-    print("RX Signal power: %.4f. Noise power: %.4f" % (signal_power, sigma2))
-
-    # Generate complex noise with given variance
-    noise = np.sqrt(sigma2 / 2) * (np.random.randn(*convolved.shape) + 1j * np.random.randn(*convolved.shape))
-    return convolved + noise
-
-
-OFDM_TX = OFDM_withCP
-OFDM_RX = channel(OFDM_TX)
-# OFDM_RX1 = optical_channel(OFDM_TX).numpy
-plt.figure(figsize=(8, 2))
-plt.plot(abs(OFDM_TX), label='TX signal')
-plt.plot(abs(OFDM_RX), label='RX signal')
-plt.legend(fontsize=10)
-plt.xlabel('Time')
-plt.ylabel('$|x(t)|$')
-plt.grid(True)
-
-
-def removeCP(signal):
-    return signal[CP:(CP + K)]
-
-
-OFDM_RX_noCP = removeCP(OFDM_RX)
-
-
-def DFT(OFDM_RX):
-    return np.fft.fft(OFDM_RX)
-
-
-OFDM_demod = DFT(OFDM_RX_noCP)
-
-
-def channelEstimate(OFDM_demod):
-    pilots = OFDM_demod[pilotCarriers]  # extract the pilot values from the RX signal
-    Hest_at_pilots = pilots / pilotValue  # divide by the transmitted pilot values
-
-    # Perform interpolation between the pilot carriers to get an estimate
-    # of the channel in the data carriers. Here, we interpolate absolute value and phase
-    # separately
-    Hest_abs = interpolate.interp1d(pilotCarriers, abs(Hest_at_pilots), kind='linear')(allCarriers)
-    Hest_phase = interpolate.interp1d(pilotCarriers, np.angle(Hest_at_pilots), kind='linear')(allCarriers)
-    Hest = Hest_abs * np.exp(1j * Hest_phase)
-
-    plt.plot(allCarriers, abs(H_exact), label='Correct Channel')
-    plt.stem(pilotCarriers, abs(Hest_at_pilots), label='Pilot estimates')
-    plt.plot(allCarriers, abs(Hest), label='Estimated channel via interpolation')
-    plt.grid(True)
-    plt.xlabel('Carrier index')
-    plt.ylabel('$|H(f)|$')
-    plt.legend(fontsize=10)
-    plt.ylim(0, 2)
-
-    return Hest
-
-
-Hest = channelEstimate(OFDM_demod)
-
-
-def equalize(OFDM_demod, Hest):
-    return OFDM_demod / Hest
-
-
-equalized_Hest = equalize(OFDM_demod, Hest)
-
-
-def get_payload(equalized):
-    return equalized[dataCarriers]
-
-
-QAM_est = get_payload(equalized_Hest)
-plt.plot(QAM_est.real, QAM_est.imag, 'bo');
-
-
-def Demapping(QAM):
-    # array of possible constellation points
-    constellation = np.array([x for x in demapping_table.keys()])
-
-    # calculate distance of each RX point to each possible point
-    dists = abs(QAM.reshape((-1, 1)) - constellation.reshape((1, -1)))
-
-    # for each element in QAM, choose the index in constellation
-    # that belongs to the nearest constellation point
-    const_index = dists.argmin(axis=1)
-
-    # get back the real constellation point
-    hardDecision = constellation[const_index]
-
-    # transform the constellation point into the bit groups
-    return np.vstack([demapping_table[C] for C in hardDecision]), hardDecision
-
-
-PS_est, hardDecision = Demapping(QAM_est)
-for qam, hard in zip(QAM_est, hardDecision):
-    plt.plot([qam.real, hard.real], [qam.imag, hard.imag], 'b-o');
-    plt.plot(hardDecision.real, hardDecision.imag, 'ro')
-
-
-def PS(bits):
-    return bits.reshape((-1,))
-
-
-bits_est = PS(PS_est)
 
-print("Obtained Bit error rate: ", np.sum(abs(bits - bits_est)) / len(bits))
+class OFDM():
+
+    def __init__(self, num_channels, channel_args, **kwargs):
+        self.num_channels = num_channels
+        self.channel_args = channel_args
+
+        if "dc_offset" in kwargs:
+            self.dc_offset = kwargs["dc_offset"]
+        else:
+            self.dc_offset = 50
+
+        self.mod_schemes = None
+        self.modulators = None
+        self.demodulators = None
+        self.bits_per_ofdm_symbol = None
+        self.set_mod_schemes()
+
+        self.tx_bits = None
+
+    def set_mod_schemes(self, new_mod_schemes=None):
+        if new_mod_schemes is None:
+            self.mod_schemes = ["qpsk" for x in range(self.num_channels)]
+        else:
+            self.mod_schemes = new_mod_schemes
+        self.modulators = {x: AlphabetMod(x, 0.0) for x in list(set(self.mod_schemes))}
+        self.demodulators = {x: AlphabetDemod(x, 0.0) for x in list(set(self.mod_schemes))}
+        self.bits_per_ofdm_symbol = 0
+
+        for mod in self.mod_schemes:
+            self.bits_per_ofdm_symbol += self.modulators[mod].N
+
+    def transmit_and_equalize(self, tx_stream):
+
+        optical_channel = OpticalChannel(**self.channel_args, num_of_samples=len(tx_stream))
+
+        rx_stream = optical_channel(tx_stream)
+        rx_stream_f = np.fft.fft(np.sqrt(rx_stream))
+        f = np.fft.fftfreq(rx_stream_f.shape[0], d=1/(self.channel_args["fs"]))
+
+        # TODO : Equalization needs to be tested
+        mul = np.exp(-0.5j*(self.channel_args["dispersion_factor"])*(self.channel_args["fiber_length"])*np.square(2*np.pi*f))
+        rx_compensated_f = rx_stream_f * mul
+        rx_compensated = np.abs(np.fft.ifft(rx_compensated_f))
+        return rx_compensated
+
+    def find_optimal_schemes(self):
+        # TODO : Choosing approapriate modulation scheme after a single pass of QPSK
+        pass
+
+    def encode_ofdm_symbol(self, bits):
+        if bits.size != self.bits_per_ofdm_symbol:
+            raise Exception("Number of bits passed does not match number of bits per OFDM symbol!")
+
+        bit_idx = 0
+        enc_stream_upper_f = []
+        for mod in self.mod_schemes:
+            num_bits = self.modulators[mod].N
+            a = bits[bit_idx:bit_idx + num_bits]
+            alph_sym = self.modulators[mod].forward(bits[bit_idx:bit_idx + num_bits]).rect[0]
+            bit_idx += num_bits
+            sym = alph_sym[0] + 1j * alph_sym[1]
+            enc_stream_upper_f.append(sym)
+
+        if len(enc_stream_upper_f) != self.num_channels:
+            raise Exception("An error occurred! Computed upper sideband contains more channels than specified")
+
+        enc_stream_upper_f = np.asarray(enc_stream_upper_f)
+        enc_stream_lower_f = np.conjugate(np.flip(enc_stream_upper_f))
+        enc_stream_f = np.concatenate((self.dc_offset, enc_stream_upper_f, enc_stream_lower_f), axis=None)
+        enc_stream_t = np.fft.ifft(enc_stream_f)
+        return np.real(enc_stream_t)
+
+    def encode_bits(self, bit_stream):
+        if bit_stream.size % self.bits_per_ofdm_symbol != 0:
+            raise Exception("Bit stream size is not an integer multiple of bits per OFDM symbol!")
+
+        self.tx_bits = bit_stream
+
+        bit_blocks = bit_stream.reshape(-1, self.bits_per_ofdm_symbol)
+        tx_stream = []
+        for bits in bit_blocks:
+            tx_stream.append(self.encode_ofdm_symbol(bits))
+        tx_stream = np.asarray(tx_stream).flatten()
+        return tx_stream
+
+    def decode_ofdm_symbol(self, stream):
+        if stream.size != (2*self.num_channels + 1):
+            raise Exception("Number of samples passed is not equal to number of samples in one OFDM symbol!")
+
+        stream_f = np.fft.fft(stream)
+        stream_upper_f = stream_f[1:self.num_channels+1]
+
+        bits = []
+        for chan, mod in zip(stream_upper_f, self.mod_schemes):
+            rf_sym = RFSignal(np.zeros((3, 3)))
+            rf_sym.set_rect_xy(np.real(chan), np.imag(chan))
+            bits.append(self.demodulators[mod].forward(rf_sym).astype(int))
+        bits = np.asarray(bits).flatten()
+        return bits
+
+    def decode_bits(self, rx_stream):
+        if rx_stream.size % (2*self.num_channels + 1) != 0:
+            raise Exception("Number of samples passed is not an integer multiple of the samples per OFDM symbol")
+        pass
+
+        ofdm_symbols = rx_stream.reshape(-1, (2*self.num_channels + 1))
+        rx_bit_stream = []
+        for sym in ofdm_symbols:
+            rx_bit_stream.append(self.decode_ofdm_symbol(sym))
+        rx_bit_stream = np.asarray(rx_bit_stream).flatten()
+        return rx_bit_stream
+
+    def computeBER(self, n_ofdm_symbols):
+        bit_stream = np.random.randint(2, size=n_ofdm_symbols * self.bits_per_ofdm_symbol)
+        tx_stream = self.encode_bits(bit_stream)
+        rx_stream = self.transmit_and_equalize(tx_stream)
+
+        tx_stream_f = np.fft.fft(tx_stream)
+        rx_stream_f = np.fft.fft(rx_stream)
+        f = np.fft.fftfreq(tx_stream_f.shape[0], d=1/(self.channel_args["fs"]))
+
+        plt.plot(f, np.real(tx_stream_f), 'x')
+        plt.plot(f, np.real(rx_stream_f), 'x')
+        plt.ylim([-5, 5])
+        plt.show()
+
+        # plt.plot(tx_stream)
+        # plt.plot(rx_stream)
+        # plt.show()
+
+        rx_bit_stream = self.decode_bits(rx_stream)
+        ber = np.sum(np.equal(bit_stream, rx_bit_stream).astype(int))/bit_stream.size
+        print(ber)
+        pass
+
+
+if __name__ == '__main__':
+    BANDWIDTH = 100e9
+    channel_args = {"fs": BANDWIDTH,
+                    "dispersion_factor": -21.7 * 1e-24,
+                    "fiber_length": 10,
+                    "fiber_length_stddev": 0,
+                    "lpf_cutoff": BANDWIDTH / 2,
+                    "rx_stddev": 0.01,
+                    "sig_avg": 0.5,
+                    "enob": 10}
+
+    # TODO : For High Cardinality or number of OFDM Symbols the decoder fails and BER ~ 0.5
+
+    ofdm_obj = OFDM(64, channel_args)
+    ofdm_obj.computeBER(100)
+
+    pass

+ 134 - 57
tests/ofdm_test.py

@@ -4,24 +4,31 @@ import numpy as np
 from models.custom_layers import OpticalChannel
 from matplotlib import pyplot as plt
 
-BANDWIDTH = 64e9
+# mapping_table = np.asarray([-3 - 3j, -3 - 1j, -3 + 3j, -3 + 1j,
+#                             -1 - 3j, -1 - 1j, -1 + 3j, -1 + 1j,
+#                             3 - 3j, 3 - 1j, 3 + 3j, 3 + 1j,
+#                             1 - 3j, 1 - 1j, 1 + 3j, 1 + 1j])
 
-CARDINALITY = 16
+mapping_table = np.asarray([-1 - 1j, -1 + 1j, 1 - 1j, 1 + 1j])
+
+BANDWIDTH = 50e9
+
+CARDINALITY = mapping_table.shape[0]
 DISPERSION_FACTOR = -21.7 * 1e-24
-FIBER_LENGTH = 50
+FIBER_LENGTH = 5
 FIBER_LENGTH_STDDEV = 0
 RX_STDDEV = 0.01
 SIG_AVG = 0.5
 ENOB = 10
 
 # Number of OFDM symbols to simulate
-OFDM_N = 50
+OFDM_N = 100
 # Number of OFDM subcarriers
-K = 16
+K = 32
 # length of the cyclic prefix: 25% of the block
 CP = K // 4
 # number of pilot carriers per OFDM block
-P = 8
+P = 0
 # The known value each pilot transmits
 pilotValue = 3 + 3j
 # DC offset used to ensure all values are positive
@@ -30,16 +37,12 @@ DC_OFFSET = np.asarray([DC_OFFSET])
 
 SHOW_PLOTS = True
 
-mapping_table = np.asarray([-3 - 3j, -3 - 1j, -3 + 3j, -3 + 1j,
-                            -1 - 3j, -1 - 1j, -1 + 3j, -1 + 1j,
-                            3 - 3j, 3 - 1j, 3 + 3j, 3 + 1j,
-                            1 - 3j, 1 - 1j, 1 + 3j, 1 + 1j])
 
 bits_per_symbol = int(math.log(CARDINALITY, 2))
 bits_lst = [list(i) for i in itertools.product([0, 1], repeat=bits_per_symbol)]
 
 # Set true to view plot of symbols as IQ plot
-if SHOW_PLOTS:
+if False:
     for idx, sym in enumerate(mapping_table):
         plt.plot(sym.real, sym.imag, 'bo')
         plt.text(sym.real, sym.imag + 0.2, str(bits_lst[idx])[1:-1], ha='center')
@@ -51,71 +54,145 @@ if SHOW_PLOTS:
     plt.ylim(-4, 4)
     plt.show()
 
+tx_stream = []
+tx_syms = []
 
-# All subcarriers used
-allCarriers = np.arange(K)
+for ofdm_sym in range(OFDM_N):
+    print(ofdm_sym)
+    # All subcarriers used
+    allCarriers = np.arange(K)
 
-# Identifying pilot carriers (and adding final subcarrier as a pilot for convenience)
-pilotCarriers = allCarriers[::K // P]
-pilotCarriers = np.hstack([pilotCarriers, np.array([allCarriers[-1]])])
-P = P + 1
+    # Identifying pilot carriers (and adding final subcarrier as a pilot for convenience)
+    if P == 0:
+        pilotCarriers = np.asarray([-1])
+    else:
+        pilotCarriers = allCarriers[::K // P]
+        pilotCarriers = np.hstack([pilotCarriers, np.array([allCarriers[-1]])])
+        # P = P + 1
 
-# Removing pilot carriers to obtain data carriers
-dataCarriers = np.delete(allCarriers, pilotCarriers)
+    # Removing pilot carriers to obtain data carriers
+    dataCarriers = np.delete(allCarriers, pilotCarriers)
 
-if SHOW_PLOTS:
-    plt.plot(pilotCarriers, np.zeros_like(pilotCarriers), 'bo', label='pilot')
-    plt.plot(dataCarriers, np.zeros_like(dataCarriers), 'ro', label='data')
-    plt.show()
+    # if SHOW_PLOTS:
+    #     f = np.fft.fftfreq(allCarriers.shape[0], d=1 / BANDWIDTH)
+    #     plt.plot(f[1]*pilotCarriers, np.zeros_like(pilotCarriers), 'bo', label='pilot')
+    #     plt.plot(f[1]*dataCarriers, np.zeros_like(dataCarriers), 'ro', label='data')
+    #     plt.show()
 
-# Generate random symbols as integers and then map symbol values onto them
-input_syms = np.random.randint(CARDINALITY, size=len(dataCarriers))
-mapped_syms = mapping_table[input_syms]
+    # Generate random symbols as integers and then map symbol values onto them
+    input_syms = np.random.randint(CARDINALITY, size=len(dataCarriers))
+    tx_syms.append(input_syms)
+    mapped_syms = mapping_table[input_syms]
 
-# Generate the upper sideband of the OFDM symbol
-enc_stream_upper_f = np.zeros(K, dtype=complex)
-enc_stream_upper_f[pilotCarriers] = pilotValue
-enc_stream_upper_f[dataCarriers] = mapped_syms
+    # Generate the upper sideband of the OFDM symbol
+    enc_stream_upper_f = np.zeros(K, dtype=complex)
+    enc_stream_upper_f[pilotCarriers] = pilotValue
+    enc_stream_upper_f[dataCarriers] = mapped_syms
 
-# Generate the lower sideband of the OFDM symbol
-enc_stream_lower_f = np.conjugate(np.flip(enc_stream_upper_f))
+    # Generate the lower sideband of the OFDM symbol
+    enc_stream_lower_f = np.conjugate(np.flip(enc_stream_upper_f))
 
-# Combine the two sidebands with a DC offset to ensure values are always positive
-enc_stream_f = np.concatenate((DC_OFFSET, enc_stream_upper_f, enc_stream_lower_f), axis=None)
+    # Combine the two sidebands with a DC offset to ensure values are always positive
+    enc_stream_f = np.concatenate((DC_OFFSET, enc_stream_upper_f, enc_stream_lower_f), axis=None)
 
-if SHOW_PLOTS:
-    f = np.fft.fftfreq(enc_stream_f.shape[0], d=1/BANDWIDTH)
-    plt.plot(f, np.real(enc_stream_f), 'x')
-    plt.plot(f, np.imag(enc_stream_f), 'x')
-    # plt.xlim(-0.5e10, 0.5e10)
-    plt.show()
+    # if SHOW_PLOTS:
+    #     f = np.fft.fftfreq(enc_stream_f.shape[0], d=1/BANDWIDTH)
+    #     plt.plot(f, np.real(enc_stream_f), 'x')
+    #     plt.plot(f, np.imag(enc_stream_f), 'x')
+    #     # plt.xlim(-0.5e10, 0.5e10)
+    #     plt.show()
 
-# Take the inverse fourier transform
-enc_stream_t = np.fft.ifft(enc_stream_f)
+    # Take the inverse fourier transform
+    enc_stream_t = np.fft.ifft(enc_stream_f)
 
-if SHOW_PLOTS:
-    t = np.arange(len(enc_stream_t))*(1/BANDWIDTH)
-    plt.plot(t, np.real(enc_stream_t))
-    plt.plot(t, np.imag(enc_stream_t))
-    plt.show()
+    # if SHOW_PLOTS:
+    #     t = np.arange(len(enc_stream_t))*(1/BANDWIDTH)
+    #     plt.plot(t, np.real(enc_stream_t))
+    #     plt.plot(t, np.imag(enc_stream_t))
+    #     plt.show()
+
+    # Take the real part to be transmitted via the channel
+    tx_stream.append(np.real(enc_stream_t))
 
-# Take the real part to be transmitted via the channel
-tx_stream = np.real(enc_stream_t)
+
+tx_stream = np.asarray(tx_stream).flatten()
+tx_syms = np.asarray(tx_syms).flatten()
 
 optical_channel = OpticalChannel(fs=BANDWIDTH,
-                                 num_of_samples=len(tx_stream),  # TODO: determine size of input to channel
                                  dispersion_factor=DISPERSION_FACTOR,
                                  fiber_length=FIBER_LENGTH,
                                  fiber_length_stddev=FIBER_LENGTH_STDDEV,
                                  lpf_cutoff=BANDWIDTH / 2,
                                  rx_stddev=RX_STDDEV,
                                  sig_avg=SIG_AVG,
-                                 enob=ENOB)
+                                 enob=ENOB,
+                                 num_of_samples=len(tx_stream))
+
+rx_stream = optical_channel(tx_stream)
+
+rx_stream_f = np.fft.fft(np.sqrt(rx_stream))
+freq = np.fft.fftfreq(rx_stream_f.shape[0], d=1/BANDWIDTH)
+# compensation = np.exp(-0.5j*DISPERSION_FACTOR*FIBER_LENGTH*np.square(2*math.pi*freq))
+
+# rx_compensated_f = rx_stream_f * compensation
+
+rx_compensated_t = np.fft.ifft(rx_stream_f)
+
+rx_reshaped_t = rx_compensated_t.reshape((OFDM_N, -1))
+
+if False:
+    t = np.arange(len(tx_stream))*(1/BANDWIDTH)
+    plt.plot(t, tx_stream)
+    plt.plot(t, 2+np.sqrt(rx_stream))
+    plt.plot(t, rx_compensated_t-2)
+    plt.show()
+
+rx_syms = []
+
+# chan_num = 2
+
+for chan_num in range(K):
+    i = 0
+    for sym in rx_reshaped_t:
+        sym_f = np.fft.fft(sym)
+
+        freq = np.fft.fftfreq(sym_f.shape[0], d=1 / BANDWIDTH)
+        # compensation = np.exp(-0.5j * DISPERSION_FACTOR * FIBER_LENGTH * np.square(2 * math.pi * freq))
+
+        sym_f = sym_f
+
+        upper_f = sym_f[1:1 + K]
+
+        if SHOW_PLOTS:
+            # for x in upper_f:
+            if tx_syms[i+chan_num] == 0:
+                plt.plot(np.real(upper_f[chan_num]), np.imag(upper_f[chan_num]), 'bo')
+            elif tx_syms[i+chan_num] == 1:
+                plt.plot(np.real(upper_f[chan_num]), np.imag(upper_f[chan_num]), 'ro')
+            elif tx_syms[i+chan_num] == 2:
+                plt.plot(np.real(upper_f[chan_num]), np.imag(upper_f[chan_num]), 'go')
+            elif tx_syms[i+chan_num] == 3:
+                plt.plot(np.real(upper_f[chan_num]), np.imag(upper_f[chan_num]), 'mo')
+            i += K
+    plt.title((str(chan_num) + " rot"))
+    plt.show()
 
-# rx_stream = optical_channel(enc_stream_t)
+# for sym in rx_reshaped_t:
+#     sym_f = np.fft.fft(sym)
+#     upper_f = sym_f[1:1 + K]
 #
-# if SHOW_PLOTS:
-#     t = np.arange(len(tx_stream))*(1/BANDWIDTH)
-#     plt.plot(t, tx_stream)
-#     plt.plot(t, rx_stream)
-#     plt.show()
+#     if SHOW_PLOTS:
+#         for x in upper_f:
+#             if tx_syms[i] == 0:
+#                 plt.plot(np.real(x), np.imag(x), 'bo')
+#             elif tx_syms[i] == 1:
+#                 plt.plot(np.real(x), np.imag(x), 'ro')
+#             elif tx_syms[i] == 2:
+#                 plt.plot(np.real(x), np.imag(x), 'go')
+#             elif tx_syms[i] == 3:
+#                 plt.plot(np.real(x), np.imag(x), 'mo')
+#             # plt.plot(np.real(mapping_table[tx_syms[i]]), np.imag(mapping_table[tx_syms[i]]), 'ro')
+#             # plt.text(np.real(sym), np.imag(sym), str(i), ha='center')
+#             # plt.text(np.real(mapping_table[tx_syms[i]]), np.imag(mapping_table[tx_syms[i]]), str(i), ha='center')
+#             i += 1
+plt.show()