| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161 |
- import numpy as np
- import tensorflow as tf
- from tensorflow.keras import layers, losses
- from matplotlib import pyplot as plt
- from models.basic import AlphabetMod, AlphabetDemod, RFSignal
- from models.custom_layers import OpticalChannel
class OFDM:
    """Real-valued OFDM transceiver simulated over ``OpticalChannel``.

    Every OFDM symbol carries one alphabet symbol per sub-channel. Its
    spectrum is laid out as::

        [dc_offset, X_1, ..., X_N, conj(X_N), ..., conj(X_1)]

    This layout is Hermitian-symmetric, so its IFFT is a real signal. One OFDM
    symbol therefore spans ``2 * num_channels + 1`` time-domain samples.
    """

    def __init__(self, num_channels, channel_args, **kwargs):
        """Create a transceiver with QPSK on every sub-channel.

        Args:
            num_channels: Number of data-bearing sub-channels (upper-sideband
                FFT bins).
            channel_args: Keyword arguments forwarded to ``OpticalChannel``.
                This class itself also reads the ``"fs"``,
                ``"dispersion_factor"`` and ``"fiber_length"`` entries.
            **kwargs: Optional ``dc_offset`` (default 50). This value is
                placed in the DC bin of every OFDM symbol. Other keys are
                ignored.
        """
        self.num_channels = num_channels
        self.channel_args = channel_args
        self.dc_offset = kwargs.get("dc_offset", 50)
        self.mod_schemes = None
        self.modulators = None
        self.demodulators = None
        self.bits_per_ofdm_symbol = None
        self.set_mod_schemes()
        # Most recent bit stream handed to encode_bits (kept for inspection).
        self.tx_bits = None

    def set_mod_schemes(self, new_mod_schemes=None):
        """Assign a modulation scheme to each sub-channel and build (de)modulators.

        Args:
            new_mod_schemes: Iterable of scheme names, one per sub-channel, in
                sub-channel order. ``None`` selects ``"qpsk"`` everywhere.

        Raises:
            ValueError: If the number of schemes differs from ``num_channels``.
                The object's state is left unchanged in that case.
        """
        if new_mod_schemes is None:
            schemes = ["qpsk"] * self.num_channels
        else:
            schemes = list(new_mod_schemes)
        if len(schemes) != self.num_channels:
            raise ValueError(
                f"Number of modulation schemes ({len(schemes)}) does not match "
                f"number of channels ({self.num_channels})"
            )
        self.mod_schemes = schemes
        # Build one (de)modulator per distinct scheme; sub-channels that use
        # the same scheme share it.
        distinct = set(schemes)
        self.modulators = {name: AlphabetMod(name, 0.0) for name in distinct}
        self.demodulators = {name: AlphabetDemod(name, 0.0) for name in distinct}
        # ``.N`` is used throughout as the number of bits per alphabet symbol.
        self.bits_per_ofdm_symbol = sum(self.modulators[name].N for name in schemes)

    def transmit_and_equalize(self, tx_stream):
        """Pass ``tx_stream`` through ``OpticalChannel`` and compensate dispersion.

        Processing steps:

        1. Take the square root of the received samples.
        2. Multiply in the frequency domain by
           ``exp(-0.5j * dispersion_factor * fiber_length * (2*pi*f)**2)``.
        3. Return the magnitude of the inverse FFT.

        NOTE(review): the sqrt presumably converts detected intensity back to
        field amplitude. Negative (noisy) samples would yield NaN, so confirm
        what ``OpticalChannel`` returns. The equalizer is still flagged as
        untested.
        """
        optical_channel = OpticalChannel(**self.channel_args, num_of_samples=len(tx_stream))
        rx_stream = optical_channel(tx_stream)
        rx_stream_f = np.fft.fft(np.sqrt(rx_stream))
        f = np.fft.fftfreq(rx_stream_f.shape[0], d=1 / self.channel_args["fs"])
        # TODO : Equalization needs to be tested
        mul = np.exp(
            -0.5j
            * self.channel_args["dispersion_factor"]
            * self.channel_args["fiber_length"]
            * np.square(2 * np.pi * f)
        )
        return np.abs(np.fft.ifft(rx_stream_f * mul))

    def find_optimal_schemes(self):
        """Not implemented yet; currently a no-op returning ``None``.

        Intended purpose: choose per-channel modulation schemes after a single
        QPSK pass.
        """
        # TODO : Choose an appropriate modulation scheme after a single pass of QPSK
        pass

    def encode_ofdm_symbol(self, bits):
        """Map exactly one OFDM symbol's worth of bits to real time samples.

        Bits are consumed in sub-channel order, with
        ``self.modulators[scheme].N`` bits per sub-channel.

        Args:
            bits: Array-like with ``bits_per_ofdm_symbol`` elements.

        Returns:
            Real ndarray of ``2 * num_channels + 1`` samples.

        Raises:
            ValueError: If ``bits`` has the wrong number of elements.
            RuntimeError: If ``mod_schemes`` does not contain one entry per
                sub-channel.
        """
        bits = np.asarray(bits)
        if bits.size != self.bits_per_ofdm_symbol:
            raise ValueError("Number of bits passed does not match number of bits per OFDM symbol!")
        bit_idx = 0
        upper_f = []
        for mod in self.mod_schemes:
            modulator = self.modulators[mod]
            num_bits = modulator.N
            alph_sym = modulator.forward(bits[bit_idx:bit_idx + num_bits]).rect[0]
            bit_idx += num_bits
            # ``rect[0]`` is read as the (real, imag) coordinates of the mapped symbol.
            upper_f.append(alph_sym[0] + 1j * alph_sym[1])
        if len(upper_f) != self.num_channels:
            raise RuntimeError(
                "An error occurred! Computed upper sideband does not contain the specified number of channels"
            )
        upper_f = np.asarray(upper_f)
        # The mirrored conjugate makes the spectrum Hermitian, so the IFFT is
        # real. np.real only drops floating-point residue.
        lower_f = np.conjugate(np.flip(upper_f))
        spectrum = np.concatenate((self.dc_offset, upper_f, lower_f), axis=None)
        return np.real(np.fft.ifft(spectrum))

    def encode_bits(self, bit_stream):
        """Encode a bit stream into concatenated OFDM symbols.

        Args:
            bit_stream: Array-like whose size is a multiple of
                ``bits_per_ofdm_symbol``. It is stored in ``self.tx_bits``.

        Returns:
            1-D real ndarray of ``(size / bits_per_ofdm_symbol) *
            (2 * num_channels + 1)`` samples. The array is empty for an empty
            input.

        Raises:
            ValueError: If the size is not a multiple of ``bits_per_ofdm_symbol``.
        """
        bit_stream = np.asarray(bit_stream)
        if bit_stream.size % self.bits_per_ofdm_symbol != 0:
            raise ValueError("Bit stream size is not an integer multiple of bits per OFDM symbol!")
        self.tx_bits = bit_stream
        bit_blocks = bit_stream.reshape(-1, self.bits_per_ofdm_symbol)
        return np.asarray([self.encode_ofdm_symbol(block) for block in bit_blocks]).flatten()

    def decode_ofdm_symbol(self, stream):
        """Recover the bits carried by one OFDM symbol.

        FFT bins ``1..num_channels`` (the upper sideband) are each demodulated
        with that sub-channel's scheme.

        Args:
            stream: Array-like of ``2 * num_channels + 1`` samples.

        Returns:
            1-D int ndarray of the bits, in sub-channel order.

        Raises:
            ValueError: If ``stream`` has the wrong number of samples.
        """
        stream = np.asarray(stream)
        if stream.size != (2 * self.num_channels + 1):
            raise ValueError("Number of samples passed is not equal to number of samples in one OFDM symbol!")
        stream_upper_f = np.fft.fft(stream)[1:self.num_channels + 1]
        chunks = []
        for chan, mod in zip(stream_upper_f, self.mod_schemes):
            rf_sym = RFSignal(np.zeros((3, 3)))
            rf_sym.set_rect_xy(np.real(chan), np.imag(chan))
            chunks.append(np.ravel(self.demodulators[mod].forward(rf_sym).astype(int)))
        # Concatenate rather than np.asarray: sub-channels may carry different
        # numbers of bits, which would make np.asarray produce a ragged array.
        if not chunks:
            return np.empty(0, dtype=int)
        return np.concatenate(chunks)

    def decode_bits(self, rx_stream):
        """Decode a stream of concatenated OFDM symbols back into bits.

        Args:
            rx_stream: Array-like whose size is a multiple of
                ``2 * num_channels + 1``.

        Returns:
            1-D ndarray of decoded bits. The array is empty for an empty input.

        Raises:
            ValueError: If the size is not a multiple of the samples per symbol.
        """
        rx_stream = np.asarray(rx_stream)
        samples_per_symbol = 2 * self.num_channels + 1
        if rx_stream.size % samples_per_symbol != 0:
            raise ValueError("Number of samples passed is not an integer multiple of the samples per OFDM symbol")
        ofdm_symbols = rx_stream.reshape(-1, samples_per_symbol)
        return np.asarray([self.decode_ofdm_symbol(sym) for sym in ofdm_symbols]).flatten()

    @staticmethod
    def _bit_error_rate(tx_bits, rx_bits):
        """Return the fraction of positions where ``tx_bits`` and ``rx_bits`` differ.

        Returns NaN when both inputs are empty.

        Raises:
            ValueError: If the two inputs have different shapes.
        """
        tx_bits = np.asarray(tx_bits)
        rx_bits = np.asarray(rx_bits)
        if tx_bits.shape != rx_bits.shape:
            raise ValueError(
                f"Transmitted and received bit streams differ in shape: {tx_bits.shape} vs {rx_bits.shape}"
            )
        if tx_bits.size == 0:
            return float("nan")
        return np.count_nonzero(tx_bits != rx_bits) / tx_bits.size

    def _plot_spectra(self, tx_stream, rx_stream):
        """Scatter-plot the real part of the TX and RX spectra (blocking debug aid)."""
        tx_stream_f = np.fft.fft(tx_stream)
        rx_stream_f = np.fft.fft(rx_stream)
        f = np.fft.fftfreq(tx_stream_f.shape[0], d=1 / self.channel_args["fs"])
        plt.plot(f, np.real(tx_stream_f), 'x')
        plt.plot(f, np.real(rx_stream_f), 'x')
        plt.ylim([-5, 5])
        plt.show()

    def computeBER(self, n_ofdm_symbols, plot=True):
        """Simulate random bits end-to-end and report the bit error rate.

        Args:
            n_ofdm_symbols: Number of OFDM symbols of random bits to send.
            plot: If True (the default), show the TX/RX spectra before
                decoding. This call blocks.

        Returns:
            The bit error rate, i.e. the fraction of bits received incorrectly.
            The value is also printed.
        """
        bit_stream = np.random.randint(2, size=n_ofdm_symbols * self.bits_per_ofdm_symbol)
        tx_stream = self.encode_bits(bit_stream)
        rx_stream = self.transmit_and_equalize(tx_stream)
        if plot:
            self._plot_spectra(tx_stream, rx_stream)
        rx_bit_stream = self.decode_bits(rx_stream)
        # Previously computed with np.equal, which measured the fraction of
        # *correct* bits.
        ber = self._bit_error_rate(bit_stream, rx_bit_stream)
        print(ber)
        return ber
def _main():
    """Run one BER simulation: a 64-channel OFDM link over 100 OFDM symbols."""
    bandwidth = 100e9
    # NOTE(review): dispersion_factor looks like beta2 in s^2/km and
    # fiber_length like km -- confirm against OpticalChannel.
    link_params = {
        "fs": bandwidth,
        "dispersion_factor": -21.7 * 1e-24,
        "fiber_length": 10,
        "fiber_length_stddev": 0,
        "lpf_cutoff": bandwidth / 2,
        "rx_stddev": 0.01,
        "sig_avg": 0.5,
        "enob": 10,
    }
    # TODO : With high cardinality or many OFDM symbols the decoder fails and BER ~ 0.5
    transceiver = OFDM(64, link_params)
    transceiver.computeBER(100)


if __name__ == '__main__':
    _main()
|