|
@@ -7,6 +7,8 @@ import pathlib
|
|
|
from itertools import chain
|
|
from itertools import chain
|
|
|
from sys import stdout
|
|
from sys import stdout
|
|
|
|
|
|
|
|
|
|
+from tensorflow.python.framework.errors_impl import NotFoundError
|
|
|
|
|
+
|
|
|
import defs
|
|
import defs
|
|
|
from graphs import get_SNR, get_AWGN_ber
|
|
from graphs import get_SNR, get_AWGN_ber
|
|
|
from models import basic
|
|
from models import basic
|
|
@@ -437,15 +439,17 @@ def _test_autoencoder_perf_qnn():
|
|
|
logger.addHandler(sh)
|
|
logger.addHandler(sh)
|
|
|
|
|
|
|
|
aenc = Autoencoder(4, -25, bipolar=True)
|
|
aenc = Autoencoder(4, -25, bipolar=True)
|
|
|
- # aenc.train(epoch_size=1e3, epochs=10)
|
|
|
|
|
# aenc.encoder.save_weights('ae_enc.bin')
|
|
# aenc.encoder.save_weights('ae_enc.bin')
|
|
|
# aenc.decoder.save_weights('ae_dec.bin')
|
|
# aenc.decoder.save_weights('ae_dec.bin')
|
|
|
# aenc.encoder.load_weights('ae_enc.bin')
|
|
# aenc.encoder.load_weights('ae_enc.bin')
|
|
|
# aenc.decoder.load_weights('ae_dec.bin')
|
|
# aenc.decoder.load_weights('ae_dec.bin')
|
|
|
- aenc.load_weights('autoencoder.bin')
|
|
|
|
|
|
|
+ try:
|
|
|
|
|
+ aenc.load_weights('autoencoder')
|
|
|
|
|
+ except NotFoundError:
|
|
|
|
|
+ aenc.train(epoch_size=1e3, epochs=10)
|
|
|
|
|
+ aenc.save_weights('autoencoder')
|
|
|
|
|
+
|
|
|
aenc.compile(optimizer='adam', loss=tf.losses.MeanSquaredError())
|
|
aenc.compile(optimizer='adam', loss=tf.losses.MeanSquaredError())
|
|
|
- aenc.build()
|
|
|
|
|
- aenc.summary()
|
|
|
|
|
|
|
|
|
|
m = aenc.N * aenc.parallel
|
|
m = aenc.N * aenc.parallel
|
|
|
view_encoder(aenc.encoder, 4, title="FP32 Alphabet")
|
|
view_encoder(aenc.encoder, 4, title="FP32 Alphabet")
|