Sfoglia il codice sorgente

Excluding unwanted files

Min 4 anni fa
parent
commit
630a5caf79
2 ha cambiato i file con 15 aggiunte e 4 eliminazioni
  1. 7 0
      .gitignore
  2. 8 4
      tests/min_test.py

+ 7 - 0
.gitignore

@@ -5,3 +5,10 @@ __pycache__
 
 # Environments
 venv/
+
+# Anything else
+*.log
+checkpoint
+*.index
+*.data-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]
+other/

+ 8 - 4
tests/min_test.py

@@ -7,6 +7,8 @@ import pathlib
 from itertools import chain
 from sys import stdout
 
+from tensorflow.python.framework.errors_impl import NotFoundError
+
 import defs
 from graphs import get_SNR, get_AWGN_ber
 from models import basic
@@ -437,15 +439,17 @@ def _test_autoencoder_perf_qnn():
     logger.addHandler(sh)
 
     aenc = Autoencoder(4, -25, bipolar=True)
-    # aenc.train(epoch_size=1e3, epochs=10)
     # aenc.encoder.save_weights('ae_enc.bin')
     # aenc.decoder.save_weights('ae_dec.bin')
     # aenc.encoder.load_weights('ae_enc.bin')
     # aenc.decoder.load_weights('ae_dec.bin')
-    aenc.load_weights('autoencoder.bin')
+    try:
+        aenc.load_weights('autoencoder')
+    except NotFoundError:
+        aenc.train(epoch_size=1e3, epochs=10)
+        aenc.save_weights('autoencoder')
+
     aenc.compile(optimizer='adam', loss=tf.losses.MeanSquaredError())
-    aenc.build()
-    aenc.summary()
 
     m = aenc.N * aenc.parallel
     view_encoder(aenc.encoder, 4, title="FP32 Alphabet")