Bläddra i källkod

layer merge fix

Min 4 år sedan
förälder
incheckning
91ef1d49d0
1 ändrade filer med 4 tillägg och 29 borttagningar
  1. 4 29
      models/layers.py

+ 4 - 29
models/layers.py

@@ -6,6 +6,7 @@ import itertools
 from tensorflow.keras import layers
 import tensorflow as tf
 import numpy as np
+import math
 
 
 class AwgnChannel(layers.Layer):
@@ -32,8 +33,8 @@ class BitsToSymbols(layers.Layer):
 
         self.cardinality = cardinality
 
-        n = int(tf.math.log(self.cardinality, 2))
-        self.pows = tf.convert_to_tensor(np.power(2, np.linspace(n-1, 0, n)).reshape(-1, 1), dtype=tf.float32)
+        n = int(math.log(self.cardinality, 2))
+        self.pows = tf.convert_to_tensor(np.power(2, np.linspace(n - 1, 0, n)).reshape(-1, 1), dtype=tf.float32)
 
     def call(self, inputs, **kwargs):
         idx = tf.cast(tf.tensordot(inputs, self.pows, axes=1), dtype=tf.int32)
@@ -45,7 +46,7 @@ class SymbolsToBits(layers.Layer):
     def __init__(self, cardinality):
         super(SymbolsToBits, self).__init__()
 
-        n = int(tf.math.log(cardinality, 2))
+        n = int(math.log(cardinality, 2))
         lst = [list(i) for i in itertools.product([0, 1], repeat=n)]
 
         # self.all_syms = tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32)
@@ -69,32 +70,6 @@ class ScaleAndOffset(layers.Layer):
         return inputs * self.scale + self.offset
 
 
-class BitsToSymbol(layers.Layer):
-    def __init__(self, cardinality, **kwargs):
-        super().__init__(**kwargs)
-        self.cardinality = cardinality
-        n = int(np.log(self.cardinality, 2))
-        self.powers = tf.convert_to_tensor(
-            np.power(2, np.linspace(n - 1, 0, n)).reshape(-1, 1),
-            dtype=tf.float32
-        )
-
-    def call(self, inputs, **kwargs):
-        idx = tf.cast(tf.tensordot(inputs, self.powers, axes=1), dtype=tf.int32)
-        return tf.one_hot(idx, self.cardinality)
-
-
-class SymbolToBits(layers.Layer):
-    def __init__(self, cardinality, **kwargs):
-        super().__init__(**kwargs)
-        n = int(np.log(cardinality, 2))
-        l = [list(i) for i in itertools.product([0, 1], repeat=n)]
-        self.all_syms = tf.transpose(tf.convert_to_tensor(np.asarray(l), dtype=tf.float32))
-
-    def call(self, inputs, **kwargs):
-        return tf.matmul(self.all_syms, inputs)
-
-
 class ExtractCentralMessage(layers.Layer):
     def __init__(self, messages_per_block, samples_per_symbol):
         """