Explorar o código

Added missing layers & fixed new_model

Min %!s(int64=4) %!d(string=hai) anos
pai
achega
4e4c333461
Modificáronse 2 ficheiros con 31 adicións e 3 borrados
  1. 29 0
      models/layers.py
  2. 2 3
      models/new_model.py

+ 29 - 0
models/layers.py

@@ -26,6 +26,35 @@ class AwgnChannel(layers.Layer):
         return self.noise_layer.call(inputs, training=True)
 
 
+class BitsToSymbols(layers.Layer):
+    def __init__(self, cardinality):
+        super(BitsToSymbols, self).__init__()
+
+        self.cardinality = cardinality
+
+        n = int(tf.math.log(self.cardinality, 2))
+        self.pows = tf.convert_to_tensor(np.power(2, np.linspace(n-1, 0, n)).reshape(-1, 1), dtype=tf.float32)
+
+    def call(self, inputs, **kwargs):
+        idx = tf.cast(tf.tensordot(inputs, self.pows, axes=1), dtype=tf.int32)
+        out = tf.one_hot(idx, self.cardinality)
+        return layers.Reshape((9, 32))(out)
+
+
+class SymbolsToBits(layers.Layer):
+    def __init__(self, cardinality):
+        super(SymbolsToBits, self).__init__()
+
+        n = int(tf.math.log(cardinality, 2))
+        lst = [list(i) for i in itertools.product([0, 1], repeat=n)]
+
+        # self.all_syms = tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32)
+        self.all_syms = tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32)
+
+    def call(self, inputs, **kwargs):
+        return tf.matmul(inputs, self.all_syms)
+
+
 class ScaleAndOffset(layers.Layer):
     """
     Scales and offsets a tensor

+ 2 - 3
models/new_model.py

@@ -1,8 +1,7 @@
 import tensorflow as tf
-from tensorflow.keras import layers, losses
-from custom_layers import ExtractCentralMessage, OpticalChannel
+from tensorflow.keras import losses
+from layers import OpticalChannel, BitsToSymbols, SymbolsToBits
 from end_to_end import EndToEndAutoencoder
-from custom_layers import BitsToSymbols, SymbolsToBits
 import numpy as np
 import math