Batch Normalization 简单理解

1:背景

由于在训练神经网络的过程中,每一层的 params是不断更新的,由于params的更新会导致下一层输入的分布情况发生改变,所以这就要求我们进行权重初始化,减小学习率。这个现象就叫做internal covariate shift。

2:idea思想

虽然可以通过whitening来加速收敛,但是需要的计算资源会很大。

而Batch Normalizationn的思想则是对于每一组batch,在网络的每一层中,分feature对输入进行normalization,对各个feature分别normalization,即对网络中每一层的单个神经元输入,计算均值和方差后,再进行normalization。

对于CNN来说normalize “Wx+b”而非 “x”,也可以忽略掉b,即normalize “Wx”,而计算均值和方差的时候,是在feature map的基础上(原来是每一个feature)

3:算法流程(对network进行normalize)

算法一

这里写图片描述

算法二

这里写图片描述

4:代码(keras)

'''
    Reference:
        Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
            http://arxiv.org/pdf/1502.03167v3.pdf
        mode: 0 -> featurewise normalization
              1 -> samplewise normalization (may sometimes outperform featurewise mode)
        momentum: momentum term in the computation of a running estimate of the mean and std of the data
'''
def __init__(self, input_shape, epsilon=1e-6, mode=0, momentum=0.9, weights=None):
    super(BatchNormalization, self).__init__()
    self.init = initializations.get("uniform")
    self.input_shape = input_shape
    self.epsilon = epsilon
    self.mode = mode
    self.momentum = momentum
    self.input = ndim_tensor(len(self.input_shape) + 1)

    self.gamma = self.init((self.input_shape))
    self.beta = shared_zeros(self.input_shape)

    self.params = [self.gamma, self.beta]
    self.running_mean = shared_zeros(self.input_shape)
    self.running_std = shared_ones((self.input_shape))
    if weights is not None:
        self.set_weights(weights)

def get_weights(self):
    return super(BatchNormalization, self).get_weights() + [self.running_mean.get_value(), self.running_std.get_value()]

def set_weights(self, weights):
    self.running_mean.set_value(floatX(weights[-2]))
    self.running_std.set_value(floatX(weights[-1]))
    super(BatchNormalization, self).set_weights(weights[:-2])

def init_updates(self):
    X = self.get_input(train=True)
    m = X.mean(axis=0)
    std = T.mean((X - m) ** 2 + self.epsilon, axis=0) ** 0.5
    mean_update = self.momentum * self.running_mean + (1-self.momentum) * m
    std_update = self.momentum * self.running_std + (1-self.momentum) * std
    self.updates = [(self.running_mean, mean_update), (self.running_std, std_update)]

def get_output(self, train):
    X = self.get_input(train)

    if self.mode == 0:
        X_normed = (X - self.running_mean) / (self.running_std + self.epsilon)

    elif self.mode == 1:
        m = X.mean(axis=-1, keepdims=True)
        std = X.std(axis=-1, keepdims=True)
        X_normed = (X - m) / (std + self.epsilon)

    out = self.gamma * X_normed + self.beta
    return out

def get_config(self):
    return {"name": self.__class__.__name__,
            "input_shape": self.input_shape,
            "epsilon": self.epsilon,
            "mode": self.mode}

本文系作者原创,转载请先联系作者: 18254275587@163.com