Jamie's Blog

Quick Snippet: Applying Keras Layers

Tags: Machine Learning, Programming

When defining neural networks, even in a framework as concise as Keras, you often find yourself writing far too much gumpf. An example might look like this:

from keras.layers import (Input, Conv2D, Conv3D, MaxPooling3D,
                          UpSampling2D, BatchNormalization, LeakyReLU, Reshape)

# encoder
x = Conv3D(64, (3, 3, 3), data_format="channels_first", activation="relu", padding="same")(x)
x = MaxPooling3D((1,2,2), data_format="channels_first")(x)
x = Conv3D(128, (3, 3, 3), data_format="channels_first", activation="relu", padding="same")(x)
x = MaxPooling3D((2,2,2), data_format="channels_first")(x)
x = Reshape((128,7,7))(x)
# decoder
x = BatchNormalization()(x)
x = Conv2D(64, (3,3), data_format="channels_first", padding="same")(x)
x = LeakyReLU(0.001)(x)
x = UpSampling2D(data_format="channels_first")(x)
x = BatchNormalization()(x)
x = Conv2D(32, (3,3), data_format="channels_first", padding="same")(x)
x = LeakyReLU(0.001)(x)
x = UpSampling2D(data_format="channels_first")(x)
x = BatchNormalization()(x)
x = Conv2D(1, (3,3), data_format="channels_first", padding="same")(x)
x = LeakyReLU(0.001)(x)
x = UpSampling2D(data_format="channels_first")(x)
x = Reshape((112,112))(x)

On each line, we define a Keras layer, and then immediately call it with x, storing the result in x. This leads to unconscionable redundancy – and this is an abridged version of a much larger network! Brevity is the soul of wit, after all, and I consider myself very witty. So, how can we do better?

Anyone who knows me knows I’m a huge fan of Haskell, a language in which this sort of thing simply wouldn’t stand – there, it would be natural to view this as a left fold of a list of callables over an initial input. Python lacks some of the machinery that makes this so easy in Haskell, but we can bootstrap the same effect in a few lines:

from functools import reduce

apply = lambda f, x: f(x)        # apply a callable to an argument
def flip(f):                     # swap a two-argument callable's argument order
    return lambda a, b: f(b, a)
apply_sequence = lambda l, x: reduce(flip(apply), l, x)   # left-fold the callables over x

apply_sequence takes l, a list of callables such as our Keras layer instances, and applies them one after another to an input x. Under the hood it is just a left fold: reduce threads the running value through the list, and flip swaps the arguments of apply so that it matches the (accumulator, element) signature reduce expects. (In Python 3, reduce lives in functools, hence the import.)
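
To sanity-check the helper with plain functions (nothing Keras-specific, just a couple of toy lambdas):

double = lambda x: x * 2
increment = lambda x: x + 1

# equivalent to increment(double(3)) == 7
assert apply_sequence([double, increment], 3) == 7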

Now we can write our network out as:

architecture = [
    # encoder
    Conv3D(64, (3, 3, 3), data_format="channels_first", activation="relu", padding="same"),
    MaxPooling3D((1,2,2), data_format="channels_first"),
    Conv3D(128, (3, 3, 3), data_format="channels_first", activation="relu", padding="same"),
    MaxPooling3D((2,2,2), data_format="channels_first"),
    Reshape((128,7,7)),
    # decoder
    BatchNormalization(),
    Conv2D(64, (3,3), data_format="channels_first", padding="same"),
    LeakyReLU(0.001),
    UpSampling2D(data_format="channels_first"),
    BatchNormalization(),
    Conv2D(32, (3,3), data_format="channels_first", padding="same"),
    LeakyReLU(0.001),
    UpSampling2D(data_format="channels_first"),
    BatchNormalization(),
    Conv2D(1, (3,3), data_format="channels_first", padding="same"),
    Activation("relu"),
    UpSampling2D(data_format="channels_first"),
    Reshape((112,112))
]

x = Input((112,112))
x = apply_sequence(architecture, x)

which gets rid of all those pesky xs. It also makes it easy to apply the same layers to multiple inputs with shared weights (as in a siamese network):

left = Input((112,112))
right = Input((112,112))
left = apply_sequence(architecture, left)
right = apply_sequence(architecture, right)
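
One caveat: to actually build a trainable model you need to keep hold of the input tensors rather than overwriting them, since Model wants both ends of the graph. A minimal sketch, reusing the illustrative shapes above (the concatenate at the end is just a placeholder for whatever head your siamese network really needs):

from keras.layers import Input, concatenate
from keras.models import Model

left_in = Input((112, 112))
right_in = Input((112, 112))

# the same layer instances are applied to both inputs, so the weights are shared
left_out = apply_sequence(architecture, left_in)
right_out = apply_sequence(architecture, right_in)

# placeholder head: glue the two branches together
merged = concatenate([left_out, right_out])
model = Model(inputs=[left_in, right_in], outputs=merged)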

Obviously the architecture list itself is still pretty horrible. Many of the arguments are shared, and most layers differ only in a single argument. You could define the shared values as named constants somewhere else, but you’d still have to type those names over and over. With some cunning function definitions, however, we can simplify the whole thing down to:

data_format = "channels_first"
BN = BatchNormalization
C3 = lambda filter_size: Conv3D(
        filter_size,
        (3, 3, 3),
        data_format=data_format,
        activation="relu",
        padding="same")
def P3(shape=(2, 2, 2), strides=None):
    return MaxPooling3D(
        shape,
        strides=strides,
        data_format=data_format)
C2 = lambda filter_size: Conv2D(
        filter_size,
        (3,3),
        data_format=data_format,
        padding="same")
U2 = lambda: UpSampling2D(data_format=data_format)
LR = lambda: LeakyReLU(0.001)

architecture = [
    # encoder
    C3(64), P3((1,2,2)),
    C3(128), P3(),
    Reshape((128,7,7)),
    # decoder
    BN(), C2(64),  LR(), U2(),
    BN(), C2(32),  LR(), U2(),
    BN(), C2(1),   LR(), U2(),
    Reshape((112,112))
]

left = Input((112,112))
right = Input((112,112))
left = apply_sequence(architecture, left)
right = apply_sequence(architecture, right)
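
As an aside, if assigning lambdas to names bothers you, functools.partial does much the same job; for instance, the C3 shortcut above could equally be written as:

from functools import partial

# same as the C3 lambda above: everything but the filter count is bound up front
C3 = partial(Conv3D,
             kernel_size=(3, 3, 3),
             data_format=data_format,
             activation="relu",
             padding="same")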

These shortcuts make the network much easier to read and understand, and make changes less error-prone: we only have to touch one line to try a different kernel size in our Conv3Ds, and adding a whole new decoder unit to the architecture is a mere 20-odd characters.
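
For instance, trying 5×5×5 kernels in every Conv3D (assuming the definitions above) is a single-line change to C3:

C3 = lambda filter_size: Conv3D(
        filter_size,
        (5, 5, 5),
        data_format=data_format,
        activation="relu",
        padding="same")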