min_test.py 16 KB


  1. """
These are unstructured, exploratory tests. Feel free to reuse this code for anything else.
  3. """
  4. import logging
  5. import pathlib
  6. from itertools import chain
  7. from sys import stdout
  8. from tensorflow.python.framework.errors_impl import NotFoundError
  9. import defs
  10. from graphs import get_SNR, get_AWGN_ber
  11. from models import basic
  12. from models.autoencoder import Autoencoder, view_encoder
  13. import matplotlib.pyplot as plt
  14. import tensorflow as tf
  15. import misc
  16. import numpy as np
  17. from models.basic import AlphabetDemod, AlphabetMod
  18. from models.optical_channel import OpticalChannel
  19. from models.quantized_net import QuantizedNeuralNetwork
  20. def _test_optics_autoencoder():
  21. ch = OpticalChannel(
  22. noise_level=-10,
  23. dispersion=-21.7,
  24. symbol_rate=10e9,
  25. sample_rate=400e9,
  26. length=10,
  27. pulse_shape='rcos',
  28. sqrt_out=True
  29. )
  30. tf.executing_eagerly()
  31. aenc = Autoencoder(4, channel=ch)
  32. aenc.train(samples=1e6)
  33. plt.plot(*get_SNR(
  34. aenc.get_modulator(),
  35. aenc.get_demodulator(),
  36. ber_func=get_AWGN_ber,
  37. samples=100000,
  38. steps=50,
  39. start=-5,
  40. stop=15
  41. ), '-', label='AE')
  42. plt.plot(*get_SNR(
  43. AlphabetMod('4pam', 10e6),
  44. AlphabetDemod('4pam', 10e6),
  45. samples=30000,
  46. steps=50,
  47. start=-5,
  48. stop=15,
  49. length=1,
  50. pulse_shape='rcos'
  51. ), '-', label='4PAM')
  52. plt.yscale('log')
  53. plt.grid()
  54. plt.xlabel('SNR dB')
  55. plt.title("Autoencoder Performance")
  56. plt.legend()
  57. plt.savefig('optics_autoencoder.eps', format='eps')
  58. plt.show()
  59. def _test_autoencoder_pretrain():
  60. # aenc = Autoencoder(4, -25)
  61. # aenc.train(samples=1e6)
  62. # plt.plot(*get_SNR(
  63. # aenc.get_modulator(),
  64. # aenc.get_demodulator(),
  65. # ber_func=get_AWGN_ber,
  66. # samples=100000,
  67. # steps=50,
  68. # start=-5,
  69. # stop=15
  70. # ), '-', label='Random AE')
  71. aenc = Autoencoder(4, -25)
  72. # aenc.fit_encoder('16qam', 3e4)
  73. aenc.fit_decoder('16qam', 1e5)
  74. plt.plot(*get_SNR(
  75. aenc.get_modulator(),
  76. aenc.get_demodulator(),
  77. ber_func=get_AWGN_ber,
  78. samples=100000,
  79. steps=50,
  80. start=-5,
  81. stop=15
  82. ), '-', label='16QAM Pre-trained AE')
  83. aenc.train(samples=3e6)
  84. plt.plot(*get_SNR(
  85. aenc.get_modulator(),
  86. aenc.get_demodulator(),
  87. ber_func=get_AWGN_ber,
  88. samples=100000,
  89. steps=50,
  90. start=-5,
  91. stop=15
  92. ), '-', label='16QAM Post-trained AE')
  93. plt.plot(*get_SNR(
  94. AlphabetMod('16qam', 10e6),
  95. AlphabetDemod('16qam', 10e6),
  96. ber_func=get_AWGN_ber,
  97. samples=100000,
  98. steps=50,
  99. start=-5,
  100. stop=15
  101. ), '-', label='16QAM')
  102. plt.yscale('log')
  103. plt.grid()
  104. plt.xlabel('SNR dB')
  105. plt.title("4Bit Autoencoder Performance")
  106. plt.legend()
  107. plt.show()
  108. class LiteTFMod(defs.Modulator):
  109. def __init__(self, name, autoencoder):
  110. super().__init__(2 ** autoencoder.N)
  111. self.autoencoder = autoencoder
  112. tflite_models_dir = pathlib.Path("/tmp/tflite/")
  113. tflite_model_file = tflite_models_dir / (name + ".tflite")
  114. self.interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))
  115. self.interpreter.allocate_tensors()
  116. pass
  117. def forward(self, binary: np.ndarray):
  118. reshaped = binary.reshape((-1, (self.N * self.autoencoder.parallel)))
  119. reshaped_ho = misc.bit_matrix2one_hot(reshaped)
  120. input_index = self.interpreter.get_input_details()[0]["index"]
  121. input_dtype = self.interpreter.get_input_details()[0]["dtype"]
  122. input_shape = self.interpreter.get_input_details()[0]["shape"]
  123. output_index = self.interpreter.get_output_details()[0]["index"]
  124. output_shape = self.interpreter.get_output_details()[0]["shape"]
  125. x = np.zeros((len(reshaped_ho), output_shape[1]))
  126. for i, ho in enumerate(reshaped_ho):
  127. self.interpreter.set_tensor(input_index, ho.reshape(input_shape).astype(input_dtype))
  128. self.interpreter.invoke()
  129. x[i] = self.interpreter.get_tensor(output_index)
  130. if self.autoencoder.bipolar:
  131. x = x * 2 - 1
  132. if self.autoencoder.parallel > 1:
  133. x = x.reshape((-1, self.autoencoder.signal_dim))
  134. f = np.zeros(x.shape[0])
  135. if self.autoencoder.signal_dim <= 1:
  136. p = np.zeros(x.shape[0])
  137. else:
  138. p = x[:, 1]
  139. x3 = misc.rect2polar(np.c_[x[:, 0], p, f])
  140. return basic.RFSignal(x3)
  141. class LiteTFDemod(defs.Demodulator):
  142. def __init__(self, name, autoencoder):
  143. super().__init__(2 ** autoencoder.N)
  144. self.autoencoder = autoencoder
  145. tflite_models_dir = pathlib.Path("/tmp/tflite/")
  146. tflite_model_file = tflite_models_dir / (name + ".tflite")
  147. self.interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))
  148. self.interpreter.allocate_tensors()
  149. def forward(self, values: defs.Signal) -> np.ndarray:
  150. if self.autoencoder.signal_dim <= 1:
  151. val = values.rect_x
  152. else:
  153. val = values.rect
  154. if self.autoencoder.parallel > 1:
  155. val = val.reshape((-1, self.autoencoder.parallel))
  156. input_index = self.interpreter.get_input_details()[0]["index"]
  157. input_dtype = self.interpreter.get_input_details()[0]["dtype"]
  158. input_shape = self.interpreter.get_input_details()[0]["shape"]
  159. output_index = self.interpreter.get_output_details()[0]["index"]
  160. output_shape = self.interpreter.get_output_details()[0]["shape"]
  161. decoded = np.zeros((len(val), output_shape[1]))
  162. for i, v in enumerate(val):
  163. self.interpreter.set_tensor(input_index, v.reshape(input_shape).astype(input_dtype))
  164. self.interpreter.invoke()
  165. decoded[i] = self.interpreter.get_tensor(output_index)
  166. result = misc.int2bit_array(decoded.argmax(axis=1), self.N * self.autoencoder.parallel)
  167. return result.reshape(-1, )
  168. def _test_autoencoder_perf():
  169. assert float(tf.__version__[:3]) >= 2.3
  170. # aenc = Autoencoder(3, -15)
  171. # aenc.train(samples=1e6)
  172. # plt.plot(*get_SNR(
  173. # aenc.get_modulator(),
  174. # aenc.get_demodulator(),
  175. # ber_func=get_AWGN_ber,
  176. # samples=100000,
  177. # steps=50,
  178. # start=-5,
  179. # stop=15
  180. # ), '-', label='3Bit AE')
  181. # aenc = Autoencoder(4, -25, bipolar=True, dtype=tf.float64)
  182. # aenc.train(samples=5e5)
  183. # plt.plot(*get_SNR(
  184. # aenc.get_modulator(),
  185. # aenc.get_demodulator(),
  186. # ber_func=get_AWGN_ber,
  187. # samples=100000,
  188. # steps=50,
  189. # start=-5,
  190. # stop=15
  191. # ), '-', label='4Bit AE F64')
  192. aenc = Autoencoder(4, -25, bipolar=True)
  193. aenc.train(epoch_size=1e3, epochs=10)
  194. # #
  195. m = aenc.N * aenc.parallel
  196. x_train = misc.bit_matrix2one_hot(misc.generate_random_bit_array(100*m).reshape((-1, m)))
  197. x_train_enc = aenc.encoder(x_train)
  198. x_train = tf.cast(x_train, tf.float32)
  199. #
  200. # plt.plot(*get_SNR(
  201. # aenc.get_modulator(),
  202. # aenc.get_demodulator(),
  203. # ber_func=get_AWGN_ber,
  204. # samples=100000,
  205. # steps=50,
  206. # start=-5,
  207. # stop=15
  208. # ), '-', label='4AE F32')
  209. # # #
  210. def save_tfline(model, name, types=None, ops=None, io_types=None, train_x=None):
  211. converter = tf.lite.TFLiteConverter.from_keras_model(model)
  212. if types is not None:
  213. converter.optimizations = [tf.lite.Optimize.DEFAULT]
  214. converter.target_spec.supported_types = types
  215. if ops is not None:
  216. converter.optimizations = [tf.lite.Optimize.DEFAULT]
  217. converter.target_spec.supported_ops = ops
  218. if io_types is not None:
  219. converter.inference_input_type = io_types
  220. converter.inference_output_type = io_types
  221. if train_x is not None:
  222. def representative_data_gen():
  223. for input_value in tf.data.Dataset.from_tensor_slices(train_x).batch(1).take(100):
  224. yield [input_value]
  225. converter.representative_dataset = representative_data_gen
  226. tflite_model = converter.convert()
  227. tflite_models_dir = pathlib.Path("/tmp/tflite/")
  228. tflite_models_dir.mkdir(exist_ok=True, parents=True)
  229. tflite_model_file = tflite_models_dir / (name + ".tflite")
  230. tflite_model_file.write_bytes(tflite_model)
  231. print("Saving models")
  232. save_tfline(aenc.encoder, "default_enc")
  233. save_tfline(aenc.decoder, "default_dec")
  234. #
  235. # save_tfline(aenc.encoder, "float16_enc", [tf.float16])
  236. # save_tfline(aenc.decoder, "float16_dec", [tf.float16])
  237. #
  238. # save_tfline(aenc.encoder, "bfloat16_enc", [tf.bfloat16])
  239. # save_tfline(aenc.decoder, "bfloat16_dec", [tf.bfloat16])
  240. INT16X8 = tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
  241. save_tfline(aenc.encoder, "int16x8_enc", ops=[INT16X8], train_x=x_train)
  242. save_tfline(aenc.decoder, "int16x8_dec", ops=[INT16X8], train_x=x_train_enc)
  243. # save_tfline(aenc.encoder, "int8_enc", ops=[tf.lite.OpsSet.TFLITE_BUILTINS_INT8], io_types=tf.uint8, train_x=x_train)
  244. # save_tfline(aenc.decoder, "int8_dec", ops=[tf.lite.OpsSet.TFLITE_BUILTINS_INT8], io_types=tf.uint8, train_x=x_train_enc)
  245. print("Testing BER vs SNR")
  246. plt.plot(*get_SNR(
  247. LiteTFMod("default_enc", aenc),
  248. LiteTFDemod("default_dec", aenc),
  249. ber_func=get_AWGN_ber,
  250. samples=100000,
  251. steps=50,
  252. start=-5,
  253. stop=15
  254. ), '-', label='4AE F32')
  255. # plt.plot(*get_SNR(
  256. # LiteTFMod("float16_enc", aenc),
  257. # LiteTFDemod("float16_dec", aenc),
  258. # ber_func=get_AWGN_ber,
  259. # samples=100000,
  260. # steps=50,
  261. # start=-5,
  262. # stop=15
  263. # ), '-', label='4AE F16')
  264. # #
  265. # plt.plot(*get_SNR(
  266. # LiteTFMod("bfloat16_enc", aenc),
  267. # LiteTFDemod("bfloat16_dec", aenc),
  268. # ber_func=get_AWGN_ber,
  269. # samples=100000,
  270. # steps=50,
  271. # start=-5,
  272. # stop=15
  273. # ), '-', label='4AE BF16')
  274. #
  275. plt.plot(*get_SNR(
  276. LiteTFMod("int16x8_enc", aenc),
  277. LiteTFDemod("int16x8_dec", aenc),
  278. ber_func=get_AWGN_ber,
  279. samples=100000,
  280. steps=50,
  281. start=-5,
  282. stop=15
  283. ), '-', label='4AE I16x8')
  284. # plt.plot(*get_SNR(
  285. # AlphabetMod('16qam', 10e6),
  286. # AlphabetDemod('16qam', 10e6),
  287. # ber_func=get_AWGN_ber,
  288. # samples=100000,
  289. # steps=50,
  290. # start=-5,
  291. # stop=15,
  292. # ), '-', label='16qam')
  293. plt.yscale('log')
  294. plt.grid()
  295. plt.xlabel('SNR dB')
  296. plt.ylabel('BER')
  297. plt.title("Autoencoder with different precision data types")
  298. plt.legend()
  299. plt.savefig('autoencoder_compression.eps', format='eps')
  300. plt.show()
  301. view_encoder(aenc.encoder, 4)
  302. # aenc = Autoencoder(5, -25)
  303. # aenc.train(samples=2e6)
  304. # plt.plot(*get_SNR(
  305. # aenc.get_modulator(),
  306. # aenc.get_demodulator(),
  307. # ber_func=get_AWGN_ber,
  308. # samples=100000,
  309. # steps=50,
  310. # start=-5,
  311. # stop=15
  312. # ), '-', label='5Bit AE')
  313. #
  314. # aenc = Autoencoder(6, -25)
  315. # aenc.train(samples=2e6)
  316. # plt.plot(*get_SNR(
  317. # aenc.get_modulator(),
  318. # aenc.get_demodulator(),
  319. # ber_func=get_AWGN_ber,
  320. # samples=100000,
  321. # steps=50,
  322. # start=-5,
  323. # stop=15
  324. # ), '-', label='6Bit AE')
  325. #
  326. # for scheme in ['64qam', '32qam', '16qam', 'qpsk', '8psk']:
  327. # plt.plot(*get_SNR(
  328. # AlphabetMod(scheme, 10e6),
  329. # AlphabetDemod(scheme, 10e6),
  330. # ber_func=get_AWGN_ber,
  331. # samples=100000,
  332. # steps=50,
  333. # start=-5,
  334. # stop=15,
  335. # ), '-', label=scheme.upper())
  336. #
  337. # plt.yscale('log')
  338. # plt.grid()
  339. # plt.xlabel('SNR dB')
  340. # plt.title("Autoencoder vs defined modulations")
  341. # plt.legend()
  342. # plt.show()
  343. def _test_autoencoder_perf2():
  344. aenc = Autoencoder(2, -20)
  345. aenc.train(samples=3e6)
  346. plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15), '-', label='2Bit AE')
  347. aenc = Autoencoder(3, -20)
  348. aenc.train(samples=3e6)
  349. plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15), '-', label='3Bit AE')
  350. aenc = Autoencoder(4, -20)
  351. aenc.train(samples=3e6)
  352. plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15), '-', label='4Bit AE')
  353. aenc = Autoencoder(5, -20)
  354. aenc.train(samples=3e6)
  355. plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15), '-', label='5Bit AE')
  356. for a in ['qpsk', '8psk', '16qam', '32qam', '64qam']:
  357. try:
  358. plt.plot(*get_SNR(AlphabetMod(a, 10e6), AlphabetDemod(a, 10e6), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15,), '-', label=a.upper())
  359. except KeyboardInterrupt:
  360. break
  361. except Exception:
  362. pass
  363. plt.yscale('log')
  364. plt.grid()
  365. plt.xlabel('SNR dB')
  366. plt.title("Autoencoder vs defined modulations")
  367. plt.legend()
  368. plt.savefig('autoencoder_mods.eps', format='eps')
  369. plt.show()
  370. # view_encoder(aenc.encoder, 2)
  371. def _test_autoencoder_perf_qnn():
  372. fh = logging.FileHandler("model_quantizing.log", mode="w+")
  373. fh.setLevel(logging.INFO)
  374. sh = logging.StreamHandler(stream=stdout)
  375. sh.setLevel(logging.INFO)
  376. logger = logging.getLogger(__name__)
  377. logger.setLevel(level=logging.INFO)
  378. logger.addHandler(fh)
  379. logger.addHandler(sh)
  380. aenc = Autoencoder(4, -25, bipolar=True)
  381. # aenc.encoder.save_weights('ae_enc.bin')
  382. # aenc.decoder.save_weights('ae_dec.bin')
  383. # aenc.encoder.load_weights('ae_enc.bin')
  384. # aenc.decoder.load_weights('ae_dec.bin')
  385. try:
  386. aenc.load_weights('autoencoder')
  387. except NotFoundError:
  388. aenc.train(epoch_size=1e3, epochs=10)
  389. aenc.save_weights('autoencoder')
  390. aenc.compile(optimizer='adam', loss=tf.losses.MeanSquaredError())
  391. m = aenc.N * aenc.parallel
  392. view_encoder(aenc.encoder, 4, title="FP32 Alphabet")
  393. batch_size = 25000
  394. x_train = misc.bit_matrix2one_hot(misc.generate_random_bit_array(batch_size*m).reshape((-1, m)))
  395. x_test = misc.bit_matrix2one_hot(misc.generate_random_bit_array(5000*m).reshape((-1, m)))
  396. bits = [np.log2(i) for i in (32,)][0]
  397. alphabet_scalars = 2 # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
  398. num_layers = sum([layer.__class__.__name__ in ('Dense',) for layer in aenc.all_layers])
  399. # for b in (3, 6, 8, 12, 16, 24, 32, 48, 64):
  400. get_data = (sample for sample in x_train)
  401. for i in range(num_layers):
  402. get_data = chain(get_data, (sample for sample in x_train))
  403. qnn = QuantizedNeuralNetwork(
  404. network=aenc,
  405. batch_size=batch_size,
  406. get_data=get_data,
  407. logger=logger,
  408. bits=np.log2(16),
  409. alphabet_scalar=alphabet_scalars,
  410. )
  411. qnn.quantize_network()
  412. qnn.quantized_net.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())
  413. view_encoder(qnn.quantized_net, 4, title=f"Quantised {16}b alphabet")
  414. # view_encoder(qnn_enc.quantized_net, 4, title=f"Quantised {b}b alphabet")
  415. # _, q_accuracy = qnn.quantized_net.evaluate(x_test, x_test, verbose=True)
  416. pass
  417. if __name__ == '__main__':
  418. # plt.plot(*get_SNR(
  419. # AlphabetMod('16qam', 10e6),
  420. # AlphabetDemod('16qam', 10e6),
  421. # ber_func=get_AWGN_ber,
  422. # samples=100000,
  423. # steps=50,
  424. # start=-5,
  425. # stop=15,
  426. # ), '-', label='16qam AWGN')
  427. #
  428. # plt.plot(*get_SNR(
  429. # AlphabetMod('16qam', 10e6),
  430. # AlphabetDemod('16qam', 10e6),
  431. # samples=100000,
  432. # steps=50,
  433. # start=-5,
  434. # stop=15,
  435. # ), '-', label='16qam OPTICAL')
  436. #
  437. # plt.yscale('log')
  438. # plt.grid()
  439. # plt.xlabel('SNR dB')
  440. # plt.show()
  441. # _test_autoencoder_perf()
  442. _test_autoencoder_perf_qnn()
  443. # _test_autoencoder_perf2()
  444. # _test_autoencoder_pretrain()
  445. # _test_optics_autoencoder()