"""
These are some unstructured tests. Feel free to use this code for anything else.
"""
import logging
import pathlib
from itertools import chain
from sys import stdout

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

import defs
import misc
from graphs import get_SNR, get_AWGN_ber
from models import basic
from models.autoencoder import Autoencoder, view_encoder
from models.basic import AlphabetDemod, AlphabetMod
from models.optical_channel import OpticalChannel
from models.quantized_net import QuantizedNeuralNetwork
def _test_optics_autoencoder():
    ch = OpticalChannel(
        noise_level=-10,
        dispersion=-21.7,
        symbol_rate=10e9,
        sample_rate=400e9,
        length=10,
        pulse_shape='rcos',
        sqrt_out=True
    )
    tf.executing_eagerly()  # note: this only reports whether eager execution is active; it does not enable it
    aenc = Autoencoder(4, channel=ch)
    aenc.train(samples=1e6)
    plt.plot(*get_SNR(
        aenc.get_modulator(),
        aenc.get_demodulator(),
        ber_func=get_AWGN_ber,
        samples=100000,
        steps=50,
        start=-5,
        stop=15
    ), '-', label='AE')
    plt.plot(*get_SNR(
        AlphabetMod('4pam', 10e6),
        AlphabetDemod('4pam', 10e6),
        samples=30000,
        steps=50,
        start=-5,
        stop=15,
        length=1,
        pulse_shape='rcos'
    ), '-', label='4PAM')
    plt.yscale('log')
    plt.grid()
    plt.xlabel('SNR dB')
    plt.title("Autoencoder Performance")
    plt.legend()
    plt.savefig('optics_autoencoder.eps', format='eps')
    plt.show()
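

# _test_autoencoder_pretrain: pre-fits the autoencoder's decoder on 16QAM before
# end-to-end training, and compares the pre-trained, post-trained and plain 16QAM
# BER curves on an AWGN channel.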
def _test_autoencoder_pretrain():
    # aenc = Autoencoder(4, -25)
    # aenc.train(samples=1e6)
    # plt.plot(*get_SNR(
    #     aenc.get_modulator(),
    #     aenc.get_demodulator(),
    #     ber_func=get_AWGN_ber,
    #     samples=100000,
    #     steps=50,
    #     start=-5,
    #     stop=15
    # ), '-', label='Random AE')
    aenc = Autoencoder(4, -25)
    # aenc.fit_encoder('16qam', 3e4)
    aenc.fit_decoder('16qam', 1e5)
    plt.plot(*get_SNR(
        aenc.get_modulator(),
        aenc.get_demodulator(),
        ber_func=get_AWGN_ber,
        samples=100000,
        steps=50,
        start=-5,
        stop=15
    ), '-', label='16QAM Pre-trained AE')
    aenc.train(samples=3e6)
    plt.plot(*get_SNR(
        aenc.get_modulator(),
        aenc.get_demodulator(),
        ber_func=get_AWGN_ber,
        samples=100000,
        steps=50,
        start=-5,
        stop=15
    ), '-', label='16QAM Post-trained AE')
    plt.plot(*get_SNR(
        AlphabetMod('16qam', 10e6),
        AlphabetDemod('16qam', 10e6),
        ber_func=get_AWGN_ber,
        samples=100000,
        steps=50,
        start=-5,
        stop=15
    ), '-', label='16QAM')
    plt.yscale('log')
    plt.grid()
    plt.xlabel('SNR dB')
    plt.title("4Bit Autoencoder Performance")
    plt.legend()
    plt.show()
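

# LiteTFMod wraps an exported TFLite encoder (loaded from /tmp/tflite/<name>.tflite)
# as a defs.Modulator, running the tf.lite.Interpreter one one-hot symbol at a time
# so the converted model can be plugged straight into get_SNR. A minimal usage
# sketch, assuming the models were already exported by save_tfline() inside
# _test_autoencoder_perf():
#
#   mod = LiteTFMod("default_enc", aenc)
#   demod = LiteTFDemod("default_dec", aenc)
#   plt.plot(*get_SNR(mod, demod, ber_func=get_AWGN_ber, samples=100000,
#                     steps=50, start=-5, stop=15), '-', label='TFLite AE')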
class LiteTFMod(defs.Modulator):
    def __init__(self, name, autoencoder):
        super().__init__(2 ** autoencoder.N)
        self.autoencoder = autoencoder
        tflite_models_dir = pathlib.Path("/tmp/tflite/")
        tflite_model_file = tflite_models_dir / (name + ".tflite")
        self.interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))
        self.interpreter.allocate_tensors()

    def forward(self, binary: np.ndarray):
        reshaped = binary.reshape((-1, (self.N * self.autoencoder.parallel)))
        reshaped_ho = misc.bit_matrix2one_hot(reshaped)
        input_index = self.interpreter.get_input_details()[0]["index"]
        input_dtype = self.interpreter.get_input_details()[0]["dtype"]
        input_shape = self.interpreter.get_input_details()[0]["shape"]
        output_index = self.interpreter.get_output_details()[0]["index"]
        output_shape = self.interpreter.get_output_details()[0]["shape"]
        x = np.zeros((len(reshaped_ho), output_shape[1]))
        for i, ho in enumerate(reshaped_ho):
            self.interpreter.set_tensor(input_index, ho.reshape(input_shape).astype(input_dtype))
            self.interpreter.invoke()
            x[i] = self.interpreter.get_tensor(output_index)
        if self.autoencoder.bipolar:
            x = x * 2 - 1
        if self.autoencoder.parallel > 1:
            x = x.reshape((-1, self.autoencoder.signal_dim))
        f = np.zeros(x.shape[0])
        if self.autoencoder.signal_dim <= 1:
            p = np.zeros(x.shape[0])
        else:
            p = x[:, 1]
        x3 = misc.rect2polar(np.c_[x[:, 0], p, f])
        return basic.RFSignal(x3)
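

# LiteTFDemod is the receive-side counterpart: it feeds the (rectangular) signal
# samples through the exported TFLite decoder and converts the argmax of each
# one-hot output back into a bit array.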
class LiteTFDemod(defs.Demodulator):
    def __init__(self, name, autoencoder):
        super().__init__(2 ** autoencoder.N)
        self.autoencoder = autoencoder
        tflite_models_dir = pathlib.Path("/tmp/tflite/")
        tflite_model_file = tflite_models_dir / (name + ".tflite")
        self.interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))
        self.interpreter.allocate_tensors()

    def forward(self, values: defs.Signal) -> np.ndarray:
        if self.autoencoder.signal_dim <= 1:
            val = values.rect_x
        else:
            val = values.rect
        if self.autoencoder.parallel > 1:
            val = val.reshape((-1, self.autoencoder.parallel))
        input_index = self.interpreter.get_input_details()[0]["index"]
        input_dtype = self.interpreter.get_input_details()[0]["dtype"]
        input_shape = self.interpreter.get_input_details()[0]["shape"]
        output_index = self.interpreter.get_output_details()[0]["index"]
        output_shape = self.interpreter.get_output_details()[0]["shape"]
        decoded = np.zeros((len(val), output_shape[1]))
        for i, v in enumerate(val):
            self.interpreter.set_tensor(input_index, v.reshape(input_shape).astype(input_dtype))
            self.interpreter.invoke()
            decoded[i] = self.interpreter.get_tensor(output_index)
        result = misc.int2bit_array(decoded.argmax(axis=1), self.N * self.autoencoder.parallel)
        return result.reshape(-1)
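

# _test_autoencoder_perf: trains a 4-bit bipolar Autoencoder, exports its encoder
# and decoder to TFLite at different precisions (float32 baseline and the
# experimental int16x8 quantisation, with float16/bfloat16/int8 variants left
# commented out), then compares the BER-vs-SNR curves of the converted models.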
def _test_autoencoder_perf():
    # Guard against older TF releases; compare (major, minor) instead of
    # float(tf.__version__[:3]), which misparses versions such as "2.10".
    major, minor = (int(v) for v in tf.__version__.split('.')[:2])
    assert (major, minor) >= (2, 3)
    # aenc = Autoencoder(3, -15)
    # aenc.train(samples=1e6)
    # plt.plot(*get_SNR(
    #     aenc.get_modulator(),
    #     aenc.get_demodulator(),
    #     ber_func=get_AWGN_ber,
    #     samples=100000,
    #     steps=50,
    #     start=-5,
    #     stop=15
    # ), '-', label='3Bit AE')
    # aenc = Autoencoder(4, -25, bipolar=True, dtype=tf.float64)
    # aenc.train(samples=5e5)
    # plt.plot(*get_SNR(
    #     aenc.get_modulator(),
    #     aenc.get_demodulator(),
    #     ber_func=get_AWGN_ber,
    #     samples=100000,
    #     steps=50,
    #     start=-5,
    #     stop=15
    # ), '-', label='4Bit AE F64')
    aenc = Autoencoder(4, -25, bipolar=True)
    aenc.train(epoch_size=1e3, epochs=10)

    m = aenc.N * aenc.parallel
    x_train = misc.bit_matrix2one_hot(misc.generate_random_bit_array(100 * m).reshape((-1, m)))
    x_train_enc = aenc.encoder(x_train)
    x_train = tf.cast(x_train, tf.float32)

    # plt.plot(*get_SNR(
    #     aenc.get_modulator(),
    #     aenc.get_demodulator(),
    #     ber_func=get_AWGN_ber,
    #     samples=100000,
    #     steps=50,
    #     start=-5,
    #     stop=15
    # ), '-', label='4AE F32')
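
    # save_tfline converts a Keras model to a TFLite flatbuffer and writes it to
    # /tmp/tflite/<name>.tflite: `types`/`ops` select the supported precisions,
    # `io_types` forces the input/output dtype, and `train_x` provides a
    # representative dataset for post-training quantisation.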
    def save_tfline(model, name, types=None, ops=None, io_types=None, train_x=None):
        converter = tf.lite.TFLiteConverter.from_keras_model(model)
        if types is not None:
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
            converter.target_spec.supported_types = types
        if ops is not None:
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
            converter.target_spec.supported_ops = ops
        if io_types is not None:
            converter.inference_input_type = io_types
            converter.inference_output_type = io_types
        if train_x is not None:
            def representative_data_gen():
                for input_value in tf.data.Dataset.from_tensor_slices(train_x).batch(1).take(100):
                    yield [input_value]

            converter.representative_dataset = representative_data_gen
        tflite_model = converter.convert()
        tflite_models_dir = pathlib.Path("/tmp/tflite/")
        tflite_models_dir.mkdir(exist_ok=True, parents=True)
        tflite_model_file = tflite_models_dir / (name + ".tflite")
        tflite_model_file.write_bytes(tflite_model)

    print("Saving models")
    save_tfline(aenc.encoder, "default_enc")
    save_tfline(aenc.decoder, "default_dec")

    # save_tfline(aenc.encoder, "float16_enc", [tf.float16])
    # save_tfline(aenc.decoder, "float16_dec", [tf.float16])

    # save_tfline(aenc.encoder, "bfloat16_enc", [tf.bfloat16])
    # save_tfline(aenc.decoder, "bfloat16_dec", [tf.bfloat16])

    INT16X8 = tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
    save_tfline(aenc.encoder, "int16x8_enc", ops=[INT16X8], train_x=x_train)
    save_tfline(aenc.decoder, "int16x8_dec", ops=[INT16X8], train_x=x_train_enc)
    # save_tfline(aenc.encoder, "int8_enc", ops=[tf.lite.OpsSet.TFLITE_BUILTINS_INT8], io_types=tf.uint8, train_x=x_train)
    # save_tfline(aenc.decoder, "int8_dec", ops=[tf.lite.OpsSet.TFLITE_BUILTINS_INT8], io_types=tf.uint8, train_x=x_train_enc)

    print("Testing BER vs SNR")
    plt.plot(*get_SNR(
        LiteTFMod("default_enc", aenc),
        LiteTFDemod("default_dec", aenc),
        ber_func=get_AWGN_ber,
        samples=100000,
        steps=50,
        start=-5,
        stop=15
    ), '-', label='4AE F32')
    # plt.plot(*get_SNR(
    #     LiteTFMod("float16_enc", aenc),
    #     LiteTFDemod("float16_dec", aenc),
    #     ber_func=get_AWGN_ber,
    #     samples=100000,
    #     steps=50,
    #     start=-5,
    #     stop=15
    # ), '-', label='4AE F16')
    # plt.plot(*get_SNR(
    #     LiteTFMod("bfloat16_enc", aenc),
    #     LiteTFDemod("bfloat16_dec", aenc),
    #     ber_func=get_AWGN_ber,
    #     samples=100000,
    #     steps=50,
    #     start=-5,
    #     stop=15
    # ), '-', label='4AE BF16')
    plt.plot(*get_SNR(
        LiteTFMod("int16x8_enc", aenc),
        LiteTFDemod("int16x8_dec", aenc),
        ber_func=get_AWGN_ber,
        samples=100000,
        steps=50,
        start=-5,
        stop=15
    ), '-', label='4AE I16x8')
    # plt.plot(*get_SNR(
    #     AlphabetMod('16qam', 10e6),
    #     AlphabetDemod('16qam', 10e6),
    #     ber_func=get_AWGN_ber,
    #     samples=100000,
    #     steps=50,
    #     start=-5,
    #     stop=15,
    # ), '-', label='16qam')
    plt.yscale('log')
    plt.grid()
    plt.xlabel('SNR dB')
    plt.ylabel('BER')
    plt.title("Autoencoder with different precision data types")
    plt.legend()
    plt.savefig('autoencoder_compression.eps', format='eps')
    plt.show()
    view_encoder(aenc.encoder, 4)
    # aenc = Autoencoder(5, -25)
    # aenc.train(samples=2e6)
    # plt.plot(*get_SNR(
    #     aenc.get_modulator(),
    #     aenc.get_demodulator(),
    #     ber_func=get_AWGN_ber,
    #     samples=100000,
    #     steps=50,
    #     start=-5,
    #     stop=15
    # ), '-', label='5Bit AE')

    # aenc = Autoencoder(6, -25)
    # aenc.train(samples=2e6)
    # plt.plot(*get_SNR(
    #     aenc.get_modulator(),
    #     aenc.get_demodulator(),
    #     ber_func=get_AWGN_ber,
    #     samples=100000,
    #     steps=50,
    #     start=-5,
    #     stop=15
    # ), '-', label='6Bit AE')

    # for scheme in ['64qam', '32qam', '16qam', 'qpsk', '8psk']:
    #     plt.plot(*get_SNR(
    #         AlphabetMod(scheme, 10e6),
    #         AlphabetDemod(scheme, 10e6),
    #         ber_func=get_AWGN_ber,
    #         samples=100000,
    #         steps=50,
    #         start=-5,
    #         stop=15,
    #     ), '-', label=scheme.upper())

    # plt.yscale('log')
    # plt.grid()
    # plt.xlabel('SNR dB')
    # plt.title("Autoencoder vs defined modulations")
    # plt.legend()
    # plt.show()
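

# _test_autoencoder_perf2: sweeps Autoencoder sizes from 2 to 5 bits and plots
# their AWGN BER curves alongside the QPSK/8PSK/16QAM/32QAM/64QAM alphabets;
# schemes that fail to simulate are silently skipped.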
def _test_autoencoder_perf2():
    aenc = Autoencoder(2, -20)
    aenc.train(samples=3e6)
    plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15), '-', label='2Bit AE')
    aenc = Autoencoder(3, -20)
    aenc.train(samples=3e6)
    plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15), '-', label='3Bit AE')
    aenc = Autoencoder(4, -20)
    aenc.train(samples=3e6)
    plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15), '-', label='4Bit AE')
    aenc = Autoencoder(5, -20)
    aenc.train(samples=3e6)
    plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15), '-', label='5Bit AE')
    for a in ['qpsk', '8psk', '16qam', '32qam', '64qam']:
        try:
            plt.plot(*get_SNR(AlphabetMod(a, 10e6), AlphabetDemod(a, 10e6), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15), '-', label=a.upper())
        except KeyboardInterrupt:
            break
        except Exception:
            pass
    plt.yscale('log')
    plt.grid()
    plt.xlabel('SNR dB')
    plt.title("Autoencoder vs defined modulations")
    plt.legend()
    plt.savefig('autoencoder_mods.eps', format='eps')
    plt.show()
    # view_encoder(aenc.encoder, 2)
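

# _test_autoencoder_perf_qnn: loads pre-trained autoencoder weights and quantises
# the network with QuantizedNeuralNetwork (bits=np.log2(16), i.e. 4), logging
# progress to model_quantizing.log and stdout, then plots the quantised encoder
# alphabet with view_encoder next to the FP32 one.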
def _test_autoencoder_perf_qnn():
    fh = logging.FileHandler("model_quantizing.log", mode="w+")
    fh.setLevel(logging.INFO)
    sh = logging.StreamHandler(stream=stdout)
    sh.setLevel(logging.INFO)
    logger = logging.getLogger(__name__)
    logger.setLevel(level=logging.INFO)
    logger.addHandler(fh)
    logger.addHandler(sh)

    aenc = Autoencoder(4, -25, bipolar=True)
    # aenc.train(epoch_size=1e3, epochs=10)
    # aenc.encoder.save_weights('ae_enc.bin')
    # aenc.decoder.save_weights('ae_dec.bin')
    # aenc.encoder.load_weights('ae_enc.bin')
    # aenc.decoder.load_weights('ae_dec.bin')
    aenc.load_weights('autoencoder.bin')
    aenc.compile(optimizer='adam', loss=tf.losses.MeanSquaredError())
    aenc.build()
    aenc.summary()

    m = aenc.N * aenc.parallel
    view_encoder(aenc.encoder, 4, title="FP32 Alphabet")

    batch_size = 25000
    x_train = misc.bit_matrix2one_hot(misc.generate_random_bit_array(batch_size * m).reshape((-1, m)))
    x_test = misc.bit_matrix2one_hot(misc.generate_random_bit_array(5000 * m).reshape((-1, m)))

    bits = [np.log2(i) for i in (32,)][0]
    alphabet_scalars = 2  # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    num_layers = sum(layer.__class__.__name__ in ('Dense',) for layer in aenc.all_layers)
    # for b in (3, 6, 8, 12, 16, 24, 32, 48, 64):
    get_data = (sample for sample in x_train)
    for i in range(num_layers):
        get_data = chain(get_data, (sample for sample in x_train))
    qnn = QuantizedNeuralNetwork(
        network=aenc,
        batch_size=batch_size,
        get_data=get_data,
        logger=logger,
        bits=np.log2(16),
        alphabet_scalar=alphabet_scalars,
    )
    qnn.quantize_network()
    qnn.quantized_net.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())
    view_encoder(qnn.quantized_net, 4, title=f"Quantised {16}b alphabet")
    # view_encoder(qnn_enc.quantized_net, 4, title=f"Quantised {b}b alphabet")
    # _, q_accuracy = qnn.quantized_net.evaluate(x_test, x_test, verbose=True)
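

# Select which experiment to run below; only the QNN test is enabled by default.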
if __name__ == '__main__':
    # plt.plot(*get_SNR(
    #     AlphabetMod('16qam', 10e6),
    #     AlphabetDemod('16qam', 10e6),
    #     ber_func=get_AWGN_ber,
    #     samples=100000,
    #     steps=50,
    #     start=-5,
    #     stop=15,
    # ), '-', label='16qam AWGN')

    # plt.plot(*get_SNR(
    #     AlphabetMod('16qam', 10e6),
    #     AlphabetDemod('16qam', 10e6),
    #     samples=100000,
    #     steps=50,
    #     start=-5,
    #     stop=15,
    # ), '-', label='16qam OPTICAL')

    # plt.yscale('log')
    # plt.grid()
    # plt.xlabel('SNR dB')
    # plt.show()

    # _test_autoencoder_perf()
    _test_autoencoder_perf_qnn()
    # _test_autoencoder_perf2()
    # _test_autoencoder_pretrain()
    # _test_optics_autoencoder()