瀏覽代碼

Added one-hot to autoencoder

Min 5 年之前
父節點
當前提交
5de0515f99
共有 1 個文件被更改,包括 23 次插入18 次删除
  1. 23 18
      models/autoencoder.py

+ 23 - 18
models/autoencoder.py

@@ -18,17 +18,19 @@ class Autoencoder(Model):
         super(Autoencoder, self).__init__()
         self.latent_dim = latent_dim
         self.encoder = tf.keras.Sequential()
-        self.encoder.add(tf.keras.Input(shape=(nary,), dtype=bool))
-        self.encoder.add(layers.Dense(units=16))
+        self.encoder.add(tf.keras.Input(shape=(2**nary,), dtype=bool))
+        self.encoder.add(layers.Dense(units=2**(nary+1)))
         # self.encoder.add(layers.Dropout(0.2))
-        self.encoder.add(layers.Dense(units=2))
+        self.encoder.add(layers.Dense(units=2**(nary+1)))
+        self.encoder.add(layers.Dense(units=2, activation="sigmoid"))
         # self.encoder.add(layers.ReLU(max_value=1.0))
 
         self.decoder = tf.keras.Sequential()
         self.decoder.add(tf.keras.Input(shape=(2,)))
-        self.decoder.add(layers.Dense(units=16))
+        self.decoder.add(layers.Dense(units=2**(nary+1)))
         # self.decoder.add(layers.Dropout(0.2))
-        self.decoder.add(layers.Dense(units=nary))
+        self.decoder.add(layers.Dense(units=2**(nary+1)))
+        self.decoder.add(layers.Dense(units=2**nary, activation="softmax"))
 
         self.randomiser = tf.random_normal_initializer(mean=0.0, stddev=0.1, seed=None)
         # self.decoder.add(layers.Softmax(units=4, dtype=bool))
@@ -46,7 +48,7 @@ class Autoencoder(Model):
 
     def call(self, x, **kwargs):
         encoded = self.encoder(x)
-        encoded = tf.clip_by_value(encoded,  clip_value_min=[0, 0], clip_value_max=[1, 2*np.pi], name=None)
+        encoded = tf.clip_by_value(encoded,  clip_value_min=0, clip_value_max=2, name=None)
         # noise = self.randomiser(shape=(-1, 2), dtype=tf.float32)
         noise = np.random.normal(0, 1, (1, 2)) * 0.2
         noisy = tf.convert_to_tensor(noise, dtype=tf.float32)
@@ -56,9 +58,11 @@ class Autoencoder(Model):
 
 def view_encoder(encoder, N, samples=1000):
     test_values = misc.generate_random_bit_array(samples).reshape((-1, N))
+    test_values_ho = misc.bit_matrix2one_hot(test_values)
     mvector = np.array([2**i for i in range(N)], dtype=int)
     symbols = (test_values * mvector).sum(axis=1)
-    encoded = misc.polar2rect(encoder(test_values).numpy())
+    encoded = encoder(test_values_ho).numpy()
+    # encoded = misc.polar2rect(encoded)
     for i in range(2**N):
         xy = encoded[symbols == i]
         plt.plot(xy[:, 0], xy[:, 1], 'x', markersize=12, label=format(i, f'0{N}b'))
@@ -81,27 +85,28 @@ if __name__ == '__main__':
     # print(f"Train data: {x_train.shape}")
     # print(f"Test data: {x_test.shape}")
 
-    samples = 3e6
     n = 4
+
+    samples = 3e6
     x_train = misc.generate_random_bit_array(samples).reshape((-1, n))
-    x_test_array = misc.generate_random_bit_array(samples * 0.25)
+    x_train_ho = misc.bit_matrix2one_hot(x_train)
+    x_test_array = misc.generate_random_bit_array(samples * 0.3)
     x_test = x_test_array.reshape((-1, n))
+    x_test_ho = misc.bit_matrix2one_hot(x_test)
 
     autoencoder = Autoencoder(n)
     autoencoder.compile(optimizer='adam', loss=losses.MeanSquaredError())
 
-    autoencoder.fit(x_train, x_train,
+    autoencoder.fit(x_train_ho, x_train_ho,
                     epochs=1,
-                    shuffle=True,
-                    validation_data=(x_test, x_test))
-
-    encoded_data = autoencoder.encoder(x_test)
-    decoded_data = autoencoder.decoder(encoded_data).numpy().reshape((-1,))
+                    shuffle=False,
+                    validation_data=(x_test_ho, x_test_ho))
 
-    result = np.zeros(x_test_array.shape, dtype=bool)
-    result[decoded_data > 0.5] = True
+    encoded_data = autoencoder.encoder(x_test_ho)
+    decoded_data = autoencoder.decoder(encoded_data).numpy()
 
-    print("Accuracy: %.4f" % accuracy_score(x_test_array, result))
+    result = misc.int2bit_array(decoded_data.argmax(axis=1), n)
+    print("Accuracy: %.4f" % accuracy_score(x_test_array, result.reshape(-1,)))
     view_encoder(autoencoder.encoder, n)
 
     pass