Explorar o código

Non-dividable sample number fix

Min %!s(int64=5) %!d(string=hai) anos
pai
achega
8810bce77c
Modificáronse 1 ficheiros con 5 adicións e 1 borrados
  1. 5 1
      models/autoencoder.py

+ 5 - 1
models/autoencoder.py

@@ -5,6 +5,7 @@ import tensorflow as tf
 from sklearn.metrics import accuracy_score
 from sklearn.metrics import accuracy_score
 from tensorflow.keras import layers, losses
 from tensorflow.keras import layers, losses
 from tensorflow.keras.models import Model
 from tensorflow.keras.models import Model
+from functools import partial
 import misc
 import misc
 import defs
 import defs
 
 
@@ -99,7 +100,10 @@ class Autoencoder(Model):
         x_train = misc.generate_random_bit_array(samples).reshape((-1, self.N))
         x_train = misc.generate_random_bit_array(samples).reshape((-1, self.N))
         x_train_ho = misc.bit_matrix2one_hot(x_train)
         x_train_ho = misc.bit_matrix2one_hot(x_train)
 
 
-        x_test_array = misc.generate_random_bit_array(samples * 0.3)
+        test_samples = samples * 0.3
+        if test_samples % self.N:
+            test_samples += self.N - (test_samples % self.N)
+        x_test_array = misc.generate_random_bit_array(test_samples)
         x_test = x_test_array.reshape((-1, self.N))
         x_test = x_test_array.reshape((-1, self.N))
         x_test_ho = misc.bit_matrix2one_hot(x_test)
         x_test_ho = misc.bit_matrix2one_hot(x_test)