Basic Deep Learning with Keras (Part 4): ResNet

Published by Jesús Utrera Burgal

In this article we present the ResNet architecture. It was introduced by Microsoft and won the ILSVRC (ImageNet Large Scale Visual Recognition Challenge) in 2015. The paper is available at the following link:

The following diagram shows the ResNet architecture:

ResNet architecture diagram

In short, the idea is to be able to stack many more layers by adding residual connections (identity shortcuts): the input of a block bypasses its layers and is added directly to their output. This makes very deep networks much easier to train.

Traditional CNN vs. CNN with a residual connection
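As a minimal sketch of this idea, written in plain NumPy rather than Keras and using a hypothetical fully connected block (residual_block, with made-up sizes), a residual block outputs relu(F(x) + x): the shortcut adds the unchanged input to whatever the stacked layers compute. If the layers' weights are all zero, F(x) = 0 and the block simply passes its (non-negative) input through, so in principle adding blocks does not make the function harder to represent.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def residual_block(x, w1, b1, w2, b2):
    """Hypothetical fully connected residual block: relu(F(x) + x).

    F(x) = w2 @ relu(w1 @ x + b1) + b2 is the residual branch; the
    shortcut adds the unchanged input x to its output.
    """
    f = w2 @ relu(w1 @ x + b1) + b2
    return relu(f + x)


rng = np.random.default_rng(0)
d = 4
x = relu(rng.normal(size=d))  # non-negative, like the output of a previous ReLU

# With all-zero weights the residual branch outputs 0 and the block acts as the identity.
zeros_w, zeros_b = np.zeros((d, d)), np.zeros(d)
y_identity = residual_block(x, zeros_w, zeros_b, zeros_w, zeros_b)

# With random weights the block adds a learned correction on top of x.
w1, w2 = rng.normal(size=(d, d)), rng.normal(size=(d, d))
y = residual_block(x, w1, np.zeros(d), w2, np.zeros(d))
print(y_identity, y)
```

In a convolutional ResNet, F is a stack of convolutions with batch normalization instead of dense layers, but the addition works the same way.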

We will run the same experiment as in the previous parts. We skip the steps that load the CIFAR-100 dataset, configure the experiment environment and import the Python libraries, since they are exactly the same.

Training the ResNet architecture

Keras already provides this architecture, but by default it requires input images of at least 197x197 pixels, so we will define a smaller variant of our own that works with 32x32 images.

# Imports needed by this block (assuming Keras 2.1.x, where conv_block and
# identity_block are module-level functions of keras.applications.resnet50)
from keras import backend as K
from keras.applications import resnet50
from keras.engine.topology import get_source_inputs
from keras.layers import (Input, Dense, Flatten, ZeroPadding2D, AveragePooling2D,
                          GlobalAveragePooling2D, GlobalMaxPooling2D)
from keras.models import Model

def CustomResNet50(include_top=True, input_tensor=None, input_shape=(32, 32, 3), pooling=None, classes=100):
    # Create a new input layer or reuse the tensor that was passed in
    if input_tensor is None:
        img_input = Input(shape=input_shape)
    else:
        if not K.is_keras_tensor(input_tensor):
            img_input = Input(tensor=input_tensor, shape=input_shape)
        else:
            img_input = input_tensor
    # BatchNormalization axis, kept from the Keras ResNet50 code
    # (conv_block and identity_block work it out on their own)
    if K.image_data_format() == 'channels_last':
        bn_axis = 3
    else:
        bn_axis = 1

    # 32x32 -> 36x36 (there is no 7x7 convolution / max-pooling stem as in ResNet50)
    x = ZeroPadding2D(padding=(2, 2), name='conv1_pad')(img_input)

    # Stage 2: conv_block uses strides=(2, 2) by default -> 18x18
    x = resnet50.conv_block(x, 3, [32, 32, 64], stage=2, block='a')
    x = resnet50.identity_block(x, 3, [32, 32, 64], stage=2, block='b')
    x = resnet50.identity_block(x, 3, [32, 32, 64], stage=2, block='c')

    # Stage 3: strides=(1, 1) keeps 18x18
    x = resnet50.conv_block(x, 3, [64, 64, 256], stage=3, block='a', strides=(1, 1))
    x = resnet50.identity_block(x, 3, [64, 64, 256], stage=3, block='b')
    x = resnet50.identity_block(x, 3, [64, 64, 256], stage=3, block='c')

    # Stage 4: -> 9x9
    x = resnet50.conv_block(x, 3, [128, 128, 512], stage=4, block='a')
    x = resnet50.identity_block(x, 3, [128, 128, 512], stage=4, block='b')
    x = resnet50.identity_block(x, 3, [128, 128, 512], stage=4, block='c')
    x = resnet50.identity_block(x, 3, [128, 128, 512], stage=4, block='d')

    # Stage 5: -> 5x5
    x = resnet50.conv_block(x, 3, [256, 256, 1024], stage=5, block='a')
    x = resnet50.identity_block(x, 3, [256, 256, 1024], stage=5, block='b')
    x = resnet50.identity_block(x, 3, [256, 256, 1024], stage=5, block='c')
    x = resnet50.identity_block(x, 3, [256, 256, 1024], stage=5, block='d')
    x = resnet50.identity_block(x, 3, [256, 256, 1024], stage=5, block='e')
    x = resnet50.identity_block(x, 3, [256, 256, 1024], stage=5, block='f')

    # Stage 6: -> 3x3
    x = resnet50.conv_block(x, 3, [512, 512, 2048], stage=6, block='a')
    x = resnet50.identity_block(x, 3, [512, 512, 2048], stage=6, block='b')
    x = resnet50.identity_block(x, 3, [512, 512, 2048], stage=6, block='c')

    # A 1x1 pool size leaves the (3, 3, 2048) tensor unchanged
    x = AveragePooling2D((1, 1), name='avg_pool')(x)

    if include_top:
        # Classifier: flatten and one softmax output per class
        x = Flatten()(x)
        x = Dense(classes, activation='softmax', name='fc1000')(x)
    else:
        # Feature extractor: optional global pooling
        if pooling == 'avg':
            x = GlobalAveragePooling2D()(x)
        elif pooling == 'max':
            x = GlobalMaxPooling2D()(x)

    # Ensure that the model takes into account
    # any potential predecessors of `input_tensor`.
    if input_tensor is not None:
        inputs = get_source_inputs(input_tensor)
    else:
        inputs = img_input
    # Create model.
    model = Model(inputs, x, name='resnet50')

    return model
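The spatial sizes listed in the model summary can be checked with a little arithmetic. This pure-Python sketch (hypothetical helper names) assumes the behaviour of Keras 2.1's conv_block: its first 1x1 convolution and its 1x1 shortcut use padding='valid' with the block's strides, while the 3x3 convolution uses padding='same'. identity_block never changes the size, so only the first block of each stage matters.

```python
def conv_output_size(n, kernel=1, stride=1, padding="valid"):
    """Spatial output size of a Keras-style 2D convolution along one axis."""
    if padding == "same":
        return -(-n // stride)          # ceil(n / stride)
    return (n - kernel) // stride + 1   # 'valid' padding


def trace_shapes(input_size=32, pad=2):
    """Spatial size after the padding and after each stage of CustomResNet50."""
    size = input_size + 2 * pad          # ZeroPadding2D((2, 2)): 32 -> 36
    sizes = {"conv1_pad": size}
    # Stride of the conv_block that opens each stage (stage 3 uses strides=(1, 1)).
    for stage, stride in [(2, 2), (3, 1), (4, 2), (5, 2), (6, 2)]:
        size = conv_output_size(size, kernel=1, stride=stride)   # 1x1 conv and shortcut
        size = conv_output_size(size, kernel=3, padding="same")  # 3x3 conv keeps the size
        sizes["stage%d" % stage] = size
    return sizes


shapes = trace_shapes()
flatten_units = shapes["stage6"] ** 2 * 2048   # inputs of the final Dense layer
print(shapes, flatten_units)
```

The last two values correspond to the (None, 3, 3, 2048) output of the final stage and the (None, 18432) output of the Flatten layer in the summary.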

We create and compile the model as in the previous parts:

def create_custom_resnet50():  
  model = CustomResNet50(include_top=True, input_tensor=None, input_shape=(32,32,3), pooling=None, classes=100)

  return model

custom_resnet50_model = create_custom_resnet50()  
custom_resnet50_model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['acc', 'mse'])  
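The values in the training log are easier to read with a reference point. This NumPy sketch (our own illustration, not part of the original experiment; the helper names are hypothetical) evaluates the standard definitions of categorical cross-entropy and mean squared error, the loss and the 'mse' metric chosen above, for a model that predicts the uniform distribution over the 100 classes. The loss is ln(100), about 4.605, and the MSE is 0.0099, close to the values the log shows for the first epoch.

```python
import numpy as np


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean over samples of -sum(y_true * log(y_pred))."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=-1)))


def mean_squared_error(y_true, y_pred):
    """Mean over samples of the per-sample mean of squared differences."""
    return float(np.mean(np.mean((y_pred - y_true) ** 2, axis=-1)))


n_classes, n_samples = 100, 8
labels = np.arange(n_samples) % n_classes
y_true = np.eye(n_classes)[labels]                             # one-hot targets
y_uniform = np.full((n_samples, n_classes), 1.0 / n_classes)   # a model that knows nothing

cce = categorical_crossentropy(y_true, y_uniform)   # ln(100), about 4.605
mse = mean_squared_error(y_true, y_uniform)         # (0.99**2 + 99 * 0.01**2) / 100
print(round(cce, 4), round(mse, 4))
```

Any loss noticeably below ln(100) therefore means the network has learned something beyond chance.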

Once this is done, let's look at a summary of the model, obtained with custom_resnet50_model.summary():


Layer (type)                    Output Shape         Param #     Connected to  
input_1 (InputLayer)            (None, 32, 32, 3)    0  
conv1_pad (ZeroPadding2D)       (None, 36, 36, 3)    0           input_1[0][0]  
res2a_branch2a (Conv2D)         (None, 18, 18, 32)   128         conv1_pad[0][0]  
bn2a_branch2a (BatchNormalizati (None, 18, 18, 32)   128         res2a_branch2a[0][0]  
activation_1 (Activation)       (None, 18, 18, 32)   0           bn2a_branch2a[0][0]  
res2a_branch2b (Conv2D)         (None, 18, 18, 32)   9248        activation_1[0][0]  
bn2a_branch2b (BatchNormalizati (None, 18, 18, 32)   128         res2a_branch2b[0][0]  
activation_2 (Activation)       (None, 18, 18, 32)   0           bn2a_branch2b[0][0]  
res2a_branch2c (Conv2D)         (None, 18, 18, 64)   2112        activation_2[0][0]  
res2a_branch1 (Conv2D)          (None, 18, 18, 64)   256         conv1_pad[0][0]  
bn2a_branch2c (BatchNormalizati (None, 18, 18, 64)   256         res2a_branch2c[0][0]  
bn2a_branch1 (BatchNormalizatio (None, 18, 18, 64)   256         res2a_branch1[0][0]  
add_1 (Add)                     (None, 18, 18, 64)   0           bn2a_branch2c[0][0]  
activation_3 (Activation)       (None, 18, 18, 64)   0           add_1[0][0]  
res2b_branch2a (Conv2D)         (None, 18, 18, 32)   2080        activation_3[0][0]  
bn2b_branch2a (BatchNormalizati (None, 18, 18, 32)   128         res2b_branch2a[0][0]  
activation_4 (Activation)       (None, 18, 18, 32)   0           bn2b_branch2a[0][0]  
res2b_branch2b (Conv2D)         (None, 18, 18, 32)   9248        activation_4[0][0]  
bn2b_branch2b (BatchNormalizati (None, 18, 18, 32)   128         res2b_branch2b[0][0]  
activation_5 (Activation)       (None, 18, 18, 32)   0           bn2b_branch2b[0][0]  
res2b_branch2c (Conv2D)         (None, 18, 18, 64)   2112        activation_5[0][0]  
bn2b_branch2c (BatchNormalizati (None, 18, 18, 64)   256         res2b_branch2c[0][0]  
add_2 (Add)                     (None, 18, 18, 64)   0           bn2b_branch2c[0][0]  
activation_6 (Activation)       (None, 18, 18, 64)   0           add_2[0][0]  
res2c_branch2a (Conv2D)         (None, 18, 18, 32)   2080        activation_6[0][0]  
bn2c_branch2a (BatchNormalizati (None, 18, 18, 32)   128         res2c_branch2a[0][0]  
activation_7 (Activation)       (None, 18, 18, 32)   0           bn2c_branch2a[0][0]  
res2c_branch2b (Conv2D)         (None, 18, 18, 32)   9248        activation_7[0][0]  
bn2c_branch2b (BatchNormalizati (None, 18, 18, 32)   128         res2c_branch2b[0][0]  
activation_8 (Activation)       (None, 18, 18, 32)   0           bn2c_branch2b[0][0]  
res2c_branch2c (Conv2D)         (None, 18, 18, 64)   2112        activation_8[0][0]  
bn2c_branch2c (BatchNormalizati (None, 18, 18, 64)   256         res2c_branch2c[0][0]  
add_3 (Add)                     (None, 18, 18, 64)   0           bn2c_branch2c[0][0]  
activation_9 (Activation)       (None, 18, 18, 64)   0           add_3[0][0]  
res3a_branch2a (Conv2D)         (None, 18, 18, 64)   4160        activation_9[0][0]  
bn3a_branch2a (BatchNormalizati (None, 18, 18, 64)   256         res3a_branch2a[0][0]  
activation_10 (Activation)      (None, 18, 18, 64)   0           bn3a_branch2a[0][0]  
res3a_branch2b (Conv2D)         (None, 18, 18, 64)   36928       activation_10[0][0]  
bn3a_branch2b (BatchNormalizati (None, 18, 18, 64)   256         res3a_branch2b[0][0]  
activation_11 (Activation)      (None, 18, 18, 64)   0           bn3a_branch2b[0][0]  
res3a_branch2c (Conv2D)         (None, 18, 18, 256)  16640       activation_11[0][0]  
res3a_branch1 (Conv2D)          (None, 18, 18, 256)  16640       activation_9[0][0]  
bn3a_branch2c (BatchNormalizati (None, 18, 18, 256)  1024        res3a_branch2c[0][0]  
bn3a_branch1 (BatchNormalizatio (None, 18, 18, 256)  1024        res3a_branch1[0][0]  
add_4 (Add)                     (None, 18, 18, 256)  0           bn3a_branch2c[0][0]  
activation_12 (Activation)      (None, 18, 18, 256)  0           add_4[0][0]  
res3b_branch2a (Conv2D)         (None, 18, 18, 64)   16448       activation_12[0][0]  
bn3b_branch2a (BatchNormalizati (None, 18, 18, 64)   256         res3b_branch2a[0][0]  
activation_13 (Activation)      (None, 18, 18, 64)   0           bn3b_branch2a[0][0]  
res3b_branch2b (Conv2D)         (None, 18, 18, 64)   36928       activation_13[0][0]  
bn3b_branch2b (BatchNormalizati (None, 18, 18, 64)   256         res3b_branch2b[0][0]  
activation_14 (Activation)      (None, 18, 18, 64)   0           bn3b_branch2b[0][0]  
res3b_branch2c (Conv2D)         (None, 18, 18, 256)  16640       activation_14[0][0]  
bn3b_branch2c (BatchNormalizati (None, 18, 18, 256)  1024        res3b_branch2c[0][0]  
add_5 (Add)                     (None, 18, 18, 256)  0           bn3b_branch2c[0][0]  
activation_15 (Activation)      (None, 18, 18, 256)  0           add_5[0][0]  
res3c_branch2a (Conv2D)         (None, 18, 18, 64)   16448       activation_15[0][0]  
bn3c_branch2a (BatchNormalizati (None, 18, 18, 64)   256         res3c_branch2a[0][0]  
activation_16 (Activation)      (None, 18, 18, 64)   0           bn3c_branch2a[0][0]  
res3c_branch2b (Conv2D)         (None, 18, 18, 64)   36928       activation_16[0][0]  
bn3c_branch2b (BatchNormalizati (None, 18, 18, 64)   256         res3c_branch2b[0][0]  
activation_17 (Activation)      (None, 18, 18, 64)   0           bn3c_branch2b[0][0]  
res3c_branch2c (Conv2D)         (None, 18, 18, 256)  16640       activation_17[0][0]  
bn3c_branch2c (BatchNormalizati (None, 18, 18, 256)  1024        res3c_branch2c[0][0]  
add_6 (Add)                     (None, 18, 18, 256)  0           bn3c_branch2c[0][0]  
activation_18 (Activation)      (None, 18, 18, 256)  0           add_6[0][0]  
res4a_branch2a (Conv2D)         (None, 9, 9, 128)    32896       activation_18[0][0]  
bn4a_branch2a (BatchNormalizati (None, 9, 9, 128)    512         res4a_branch2a[0][0]  
activation_19 (Activation)      (None, 9, 9, 128)    0           bn4a_branch2a[0][0]  
res4a_branch2b (Conv2D)         (None, 9, 9, 128)    147584      activation_19[0][0]  
bn4a_branch2b (BatchNormalizati (None, 9, 9, 128)    512         res4a_branch2b[0][0]  
activation_20 (Activation)      (None, 9, 9, 128)    0           bn4a_branch2b[0][0]  
res4a_branch2c (Conv2D)         (None, 9, 9, 512)    66048       activation_20[0][0]  
res4a_branch1 (Conv2D)          (None, 9, 9, 512)    131584      activation_18[0][0]  
bn4a_branch2c (BatchNormalizati (None, 9, 9, 512)    2048        res4a_branch2c[0][0]  
bn4a_branch1 (BatchNormalizatio (None, 9, 9, 512)    2048        res4a_branch1[0][0]  
add_7 (Add)                     (None, 9, 9, 512)    0           bn4a_branch2c[0][0]  
activation_21 (Activation)      (None, 9, 9, 512)    0           add_7[0][0]  
res4b_branch2a (Conv2D)         (None, 9, 9, 128)    65664       activation_21[0][0]  
bn4b_branch2a (BatchNormalizati (None, 9, 9, 128)    512         res4b_branch2a[0][0]  
activation_22 (Activation)      (None, 9, 9, 128)    0           bn4b_branch2a[0][0]  
res4b_branch2b (Conv2D)         (None, 9, 9, 128)    147584      activation_22[0][0]  
bn4b_branch2b (BatchNormalizati (None, 9, 9, 128)    512         res4b_branch2b[0][0]  
activation_23 (Activation)      (None, 9, 9, 128)    0           bn4b_branch2b[0][0]  
res4b_branch2c (Conv2D)         (None, 9, 9, 512)    66048       activation_23[0][0]  
bn4b_branch2c (BatchNormalizati (None, 9, 9, 512)    2048        res4b_branch2c[0][0]  
add_8 (Add)                     (None, 9, 9, 512)    0           bn4b_branch2c[0][0]  
activation_24 (Activation)      (None, 9, 9, 512)    0           add_8[0][0]  
res4c_branch2a (Conv2D)         (None, 9, 9, 128)    65664       activation_24[0][0]  
bn4c_branch2a (BatchNormalizati (None, 9, 9, 128)    512         res4c_branch2a[0][0]  
activation_25 (Activation)      (None, 9, 9, 128)    0           bn4c_branch2a[0][0]  
res4c_branch2b (Conv2D)         (None, 9, 9, 128)    147584      activation_25[0][0]  
bn4c_branch2b (BatchNormalizati (None, 9, 9, 128)    512         res4c_branch2b[0][0]  
activation_26 (Activation)      (None, 9, 9, 128)    0           bn4c_branch2b[0][0]  
res4c_branch2c (Conv2D)         (None, 9, 9, 512)    66048       activation_26[0][0]  
bn4c_branch2c (BatchNormalizati (None, 9, 9, 512)    2048        res4c_branch2c[0][0]  
add_9 (Add)                     (None, 9, 9, 512)    0           bn4c_branch2c[0][0]  
activation_27 (Activation)      (None, 9, 9, 512)    0           add_9[0][0]  
res4d_branch2a (Conv2D)         (None, 9, 9, 128)    65664       activation_27[0][0]  
bn4d_branch2a (BatchNormalizati (None, 9, 9, 128)    512         res4d_branch2a[0][0]  
activation_28 (Activation)      (None, 9, 9, 128)    0           bn4d_branch2a[0][0]  
res4d_branch2b (Conv2D)         (None, 9, 9, 128)    147584      activation_28[0][0]  
bn4d_branch2b (BatchNormalizati (None, 9, 9, 128)    512         res4d_branch2b[0][0]  
activation_29 (Activation)      (None, 9, 9, 128)    0           bn4d_branch2b[0][0]  
res4d_branch2c (Conv2D)         (None, 9, 9, 512)    66048       activation_29[0][0]  
bn4d_branch2c (BatchNormalizati (None, 9, 9, 512)    2048        res4d_branch2c[0][0]  
add_10 (Add)                    (None, 9, 9, 512)    0           bn4d_branch2c[0][0]  
activation_30 (Activation)      (None, 9, 9, 512)    0           add_10[0][0]  
res5a_branch2a (Conv2D)         (None, 5, 5, 256)    131328      activation_30[0][0]  
bn5a_branch2a (BatchNormalizati (None, 5, 5, 256)    1024        res5a_branch2a[0][0]  
activation_31 (Activation)      (None, 5, 5, 256)    0           bn5a_branch2a[0][0]  
res5a_branch2b (Conv2D)         (None, 5, 5, 256)    590080      activation_31[0][0]  
bn5a_branch2b (BatchNormalizati (None, 5, 5, 256)    1024        res5a_branch2b[0][0]  
activation_32 (Activation)      (None, 5, 5, 256)    0           bn5a_branch2b[0][0]  
res5a_branch2c (Conv2D)         (None, 5, 5, 1024)   263168      activation_32[0][0]  
res5a_branch1 (Conv2D)          (None, 5, 5, 1024)   525312      activation_30[0][0]  
bn5a_branch2c (BatchNormalizati (None, 5, 5, 1024)   4096        res5a_branch2c[0][0]  
bn5a_branch1 (BatchNormalizatio (None, 5, 5, 1024)   4096        res5a_branch1[0][0]  
add_11 (Add)                    (None, 5, 5, 1024)   0           bn5a_branch2c[0][0]  
activation_33 (Activation)      (None, 5, 5, 1024)   0           add_11[0][0]  
res5b_branch2a (Conv2D)         (None, 5, 5, 256)    262400      activation_33[0][0]  
bn5b_branch2a (BatchNormalizati (None, 5, 5, 256)    1024        res5b_branch2a[0][0]  
activation_34 (Activation)      (None, 5, 5, 256)    0           bn5b_branch2a[0][0]  
res5b_branch2b (Conv2D)         (None, 5, 5, 256)    590080      activation_34[0][0]  
bn5b_branch2b (BatchNormalizati (None, 5, 5, 256)    1024        res5b_branch2b[0][0]  
activation_35 (Activation)      (None, 5, 5, 256)    0           bn5b_branch2b[0][0]  
res5b_branch2c (Conv2D)         (None, 5, 5, 1024)   263168      activation_35[0][0]  
bn5b_branch2c (BatchNormalizati (None, 5, 5, 1024)   4096        res5b_branch2c[0][0]  
add_12 (Add)                    (None, 5, 5, 1024)   0           bn5b_branch2c[0][0]  
activation_36 (Activation)      (None, 5, 5, 1024)   0           add_12[0][0]  
res5c_branch2a (Conv2D)         (None, 5, 5, 256)    262400      activation_36[0][0]  
bn5c_branch2a (BatchNormalizati (None, 5, 5, 256)    1024        res5c_branch2a[0][0]  
activation_37 (Activation)      (None, 5, 5, 256)    0           bn5c_branch2a[0][0]  
res5c_branch2b (Conv2D)         (None, 5, 5, 256)    590080      activation_37[0][0]  
bn5c_branch2b (BatchNormalizati (None, 5, 5, 256)    1024        res5c_branch2b[0][0]  
activation_38 (Activation)      (None, 5, 5, 256)    0           bn5c_branch2b[0][0]  
res5c_branch2c (Conv2D)         (None, 5, 5, 1024)   263168      activation_38[0][0]  
bn5c_branch2c (BatchNormalizati (None, 5, 5, 1024)   4096        res5c_branch2c[0][0]  
add_13 (Add)                    (None, 5, 5, 1024)   0           bn5c_branch2c[0][0]  
activation_39 (Activation)      (None, 5, 5, 1024)   0           add_13[0][0]  
res5d_branch2a (Conv2D)         (None, 5, 5, 256)    262400      activation_39[0][0]  
bn5d_branch2a (BatchNormalizati (None, 5, 5, 256)    1024        res5d_branch2a[0][0]  
activation_40 (Activation)      (None, 5, 5, 256)    0           bn5d_branch2a[0][0]  
res5d_branch2b (Conv2D)         (None, 5, 5, 256)    590080      activation_40[0][0]  
bn5d_branch2b (BatchNormalizati (None, 5, 5, 256)    1024        res5d_branch2b[0][0]  
activation_41 (Activation)      (None, 5, 5, 256)    0           bn5d_branch2b[0][0]  
res5d_branch2c (Conv2D)         (None, 5, 5, 1024)   263168      activation_41[0][0]  
bn5d_branch2c (BatchNormalizati (None, 5, 5, 1024)   4096        res5d_branch2c[0][0]  
add_14 (Add)                    (None, 5, 5, 1024)   0           bn5d_branch2c[0][0]  
activation_42 (Activation)      (None, 5, 5, 1024)   0           add_14[0][0]  
res5e_branch2a (Conv2D)         (None, 5, 5, 256)    262400      activation_42[0][0]  
bn5e_branch2a (BatchNormalizati (None, 5, 5, 256)    1024        res5e_branch2a[0][0]  
activation_43 (Activation)      (None, 5, 5, 256)    0           bn5e_branch2a[0][0]  
res5e_branch2b (Conv2D)         (None, 5, 5, 256)    590080      activation_43[0][0]  
bn5e_branch2b (BatchNormalizati (None, 5, 5, 256)    1024        res5e_branch2b[0][0]  
activation_44 (Activation)      (None, 5, 5, 256)    0           bn5e_branch2b[0][0]  
res5e_branch2c (Conv2D)         (None, 5, 5, 1024)   263168      activation_44[0][0]  
bn5e_branch2c (BatchNormalizati (None, 5, 5, 1024)   4096        res5e_branch2c[0][0]  
add_15 (Add)                    (None, 5, 5, 1024)   0           bn5e_branch2c[0][0]  
activation_45 (Activation)      (None, 5, 5, 1024)   0           add_15[0][0]  
res5f_branch2a (Conv2D)         (None, 5, 5, 256)    262400      activation_45[0][0]  
bn5f_branch2a (BatchNormalizati (None, 5, 5, 256)    1024        res5f_branch2a[0][0]  
activation_46 (Activation)      (None, 5, 5, 256)    0           bn5f_branch2a[0][0]  
res5f_branch2b (Conv2D)         (None, 5, 5, 256)    590080      activation_46[0][0]  
bn5f_branch2b (BatchNormalizati (None, 5, 5, 256)    1024        res5f_branch2b[0][0]  
activation_47 (Activation)      (None, 5, 5, 256)    0           bn5f_branch2b[0][0]  
res5f_branch2c (Conv2D)         (None, 5, 5, 1024)   263168      activation_47[0][0]  
bn5f_branch2c (BatchNormalizati (None, 5, 5, 1024)   4096        res5f_branch2c[0][0]  
add_16 (Add)                    (None, 5, 5, 1024)   0           bn5f_branch2c[0][0]  
activation_48 (Activation)      (None, 5, 5, 1024)   0           add_16[0][0]  
res6a_branch2a (Conv2D)         (None, 3, 3, 512)    524800      activation_48[0][0]  
bn6a_branch2a (BatchNormalizati (None, 3, 3, 512)    2048        res6a_branch2a[0][0]  
activation_49 (Activation)      (None, 3, 3, 512)    0           bn6a_branch2a[0][0]  
res6a_branch2b (Conv2D)         (None, 3, 3, 512)    2359808     activation_49[0][0]  
bn6a_branch2b (BatchNormalizati (None, 3, 3, 512)    2048        res6a_branch2b[0][0]  
activation_50 (Activation)      (None, 3, 3, 512)    0           bn6a_branch2b[0][0]  
res6a_branch2c (Conv2D)         (None, 3, 3, 2048)   1050624     activation_50[0][0]  
res6a_branch1 (Conv2D)          (None, 3, 3, 2048)   2099200     activation_48[0][0]  
bn6a_branch2c (BatchNormalizati (None, 3, 3, 2048)   8192        res6a_branch2c[0][0]  
bn6a_branch1 (BatchNormalizatio (None, 3, 3, 2048)   8192        res6a_branch1[0][0]  
add_17 (Add)                    (None, 3, 3, 2048)   0           bn6a_branch2c[0][0]  
activation_51 (Activation)      (None, 3, 3, 2048)   0           add_17[0][0]  
res6b_branch2a (Conv2D)         (None, 3, 3, 512)    1049088     activation_51[0][0]  
bn6b_branch2a (BatchNormalizati (None, 3, 3, 512)    2048        res6b_branch2a[0][0]  
activation_52 (Activation)      (None, 3, 3, 512)    0           bn6b_branch2a[0][0]  
res6b_branch2b (Conv2D)         (None, 3, 3, 512)    2359808     activation_52[0][0]  
bn6b_branch2b (BatchNormalizati (None, 3, 3, 512)    2048        res6b_branch2b[0][0]  
activation_53 (Activation)      (None, 3, 3, 512)    0           bn6b_branch2b[0][0]  
res6b_branch2c (Conv2D)         (None, 3, 3, 2048)   1050624     activation_53[0][0]  
bn6b_branch2c (BatchNormalizati (None, 3, 3, 2048)   8192        res6b_branch2c[0][0]  
add_18 (Add)                    (None, 3, 3, 2048)   0           bn6b_branch2c[0][0]  
activation_54 (Activation)      (None, 3, 3, 2048)   0           add_18[0][0]  
res6c_branch2a (Conv2D)         (None, 3, 3, 512)    1049088     activation_54[0][0]  
bn6c_branch2a (BatchNormalizati (None, 3, 3, 512)    2048        res6c_branch2a[0][0]  
activation_55 (Activation)      (None, 3, 3, 512)    0           bn6c_branch2a[0][0]  
res6c_branch2b (Conv2D)         (None, 3, 3, 512)    2359808     activation_55[0][0]  
bn6c_branch2b (BatchNormalizati (None, 3, 3, 512)    2048        res6c_branch2b[0][0]  
activation_56 (Activation)      (None, 3, 3, 512)    0           bn6c_branch2b[0][0]  
res6c_branch2c (Conv2D)         (None, 3, 3, 2048)   1050624     activation_56[0][0]  
bn6c_branch2c (BatchNormalizati (None, 3, 3, 2048)   8192        res6c_branch2c[0][0]  
add_19 (Add)                    (None, 3, 3, 2048)   0           bn6c_branch2c[0][0]  
activation_57 (Activation)      (None, 3, 3, 2048)   0           add_19[0][0]  
avg_pool (AveragePooling2D)     (None, 3, 3, 2048)   0           activation_57[0][0]  
flatten_1 (Flatten)             (None, 18432)        0           avg_pool[0][0]  
fc1000 (Dense)                  (None, 100)          1843300     flatten_1[0][0]  
Total params: 25,461,700  
Trainable params: 25,407,812  
Non-trainable params: 53,888  

Recall that the VGG-16 architecture had roughly 34 million trainable parameters. So we have increased the depth while reducing the number of trainable parameters (to about 25.4 million).
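The parameter counts in the summary can be reproduced by hand. This pure-Python sketch (hypothetical helper names) assumes that every Conv2D has a bias, that each BatchNormalization layer stores 4 values per channel (gamma and beta are trainable, the moving mean and variance are not), and the block structure of Keras 2.1's conv_block (three convolutions plus a 1x1 convolution on the shortcut) and identity_block (three convolutions, plain shortcut).

```python
def conv_params(k, c_in, c_out):
    """Parameters of a k x k Conv2D with bias."""
    return k * k * c_in * c_out + c_out


def block_params(c_in, filters, shortcut_conv):
    """(trainable, non_trainable) parameters of one bottleneck block."""
    f1, f2, f3 = filters
    convs = [conv_params(1, c_in, f1), conv_params(3, f1, f2), conv_params(1, f2, f3)]
    bn_channels = [f1, f2, f3]
    if shortcut_conv:                    # conv_block: 1x1 conv + BN on the shortcut
        convs.append(conv_params(1, c_in, f3))
        bn_channels.append(f3)
    # BatchNormalization: gamma, beta (trainable) + moving mean, variance (not trainable)
    trainable = sum(convs) + sum(2 * c for c in bn_channels)
    non_trainable = sum(2 * c for c in bn_channels)
    return trainable, non_trainable


# (filters, identity blocks after the conv_block) for stages 2..6 of CustomResNet50
STAGES = [([32, 32, 64], 2), ([64, 64, 256], 2), ([128, 128, 512], 3),
          ([256, 256, 1024], 5), ([512, 512, 2048], 2)]


def count_params(input_channels=3, final_size=3, classes=100):
    trainable = non_trainable = 0
    c_in = input_channels
    for filters, n_identity in STAGES:
        for i in range(1 + n_identity):
            t, n = block_params(c_in, filters, shortcut_conv=(i == 0))
            trainable += t
            non_trainable += n
            c_in = filters[2]
    # Flatten (3 * 3 * 2048 = 18432 units) + Dense(classes)
    trainable += final_size * final_size * c_in * classes + classes
    return trainable + non_trainable, trainable, non_trainable


total, trainable, non_trainable = count_params()
print(total, trainable, non_trainable)
```

Under these assumptions the sketch reproduces the 25,461,700 total, 25,407,812 trainable and 53,888 non-trainable parameters reported above; the single Dense layer alone accounts for 1,843,300 of them.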

With that said, let's train the model.

crn50 = custom_resnet50_model.fit(x=x_train, y=y_train, batch_size=32, epochs=10, verbose=1, validation_data=(x_test, y_test), shuffle=True)

Train on 50000 samples, validate on 10000 samples  
Epoch 1/10  
50000/50000 [==============================] - 441s 9ms/step - loss: 4.5655 - acc: 0.0817 - mean_squared_error: 0.0101 - val_loss: 4.2085 - val_acc: 0.1228 - val_mean_squared_error: 0.0099  
Epoch 2/10  
 50000/50000 [==============================] - 434s 9ms/step - loss: 4.1448 - acc: 0.1348 - mean_squared_error: 0.0098 - val_loss: 4.2032 - val_acc: 0.1236 - val_mean_squared_error: 0.0099
Epoch 3/10  
 50000/50000 [==============================] - 433s 9ms/step - loss: 4.2682 - acc: 0.1146 - mean_squared_error: 0.0099 - val_loss: 4.3306 - val_acc: 0.1066 - val_mean_squared_error: 0.0100
Epoch 4/10  
 50000/50000 [==============================] - 434s 9ms/step - loss: 4.1581 - acc: 0.1340 - mean_squared_error: 0.0098 - val_loss: 4.1405 - val_acc: 0.1384 - val_mean_squared_error: 0.0098
Epoch 5/10  
 50000/50000 [==============================] - 431s 9ms/step - loss: 3.9395 - acc: 0.1653 - mean_squared_error: 0.0096 - val_loss: 3.8838 - val_acc: 0.1718 - val_mean_squared_error: 0.0095
Epoch 6/10  
 50000/50000 [==============================] - 432s 9ms/step - loss: 3.9598 - acc: 0.1698 - mean_squared_error: 0.0096 - val_loss: 4.0047 - val_acc: 0.1608 - val_mean_squared_error: 0.0096
Epoch 7/10  
 50000/50000 [==============================] - 433s 9ms/step - loss: 3.8715 - acc: 0.1797 - mean_squared_error: 0.0095 - val_loss: 4.2620 - val_acc: 0.1184 - val_mean_squared_error: 0.0099
Epoch 8/10  
 50000/50000 [==============================] - 434s 9ms/step - loss: 3.9661 - acc: 0.1666 - mean_squared_error: 0.0096 - val_loss: 3.8181 - val_acc: 0.1898 - val_mean_squared_error: 0.0095
Epoch 9/10  
 50000/50000 [==============================] - 434s 9ms/step - loss: 3.8110 - acc: 0.1901 - mean_squared_error: 0.0095 - val_loss: 3.7521 - val_acc: 0.1966 - val_mean_squared_error: 0.0094
Epoch 10/10  
 50000/50000 [==============================] - 432s 9ms/step - loss: 3.7247 - acc: 0.2048 - mean_squared_error: 0.0094 - val_loss: 3.8206 - val_acc: 0.1929 - val_mean_squared_error: 0.0095

Let's plot the training and validation metrics.

plt.rcParams['figure.figsize'] = (8, 6)  # must be set before the figures are created
for metric, name in [('acc', 'Accuracy'), ('loss', 'Loss')]:
    plt.figure()
    plt.plot(crn50.history[metric], 'r')           # training
    plt.plot(crn50.history['val_' + metric], 'g')  # validation
    plt.xticks(np.arange(0, 11, 2.0))
    plt.xlabel("Num of Epochs")
    plt.ylabel(name)
    plt.title("Training {0} vs Validation {0}".format(name))
    plt.legend(['train', 'validation'])
plt.show()

The training runs smoothly and the model generalizes well: at the last epoch the gap between training and validation accuracy is only 0.0119 (0.2048 vs. 0.1929).
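The History object returned by fit keeps these values per epoch in crn50.history, a dict of lists keyed by metric name, so the gap can be computed directly. The sketch below hard-codes the accuracies from the training log so that it runs on its own; the helper names are hypothetical. It also reports the epoch with the best validation accuracy.

```python
# Accuracy per epoch, copied from the training log above
# (with Keras these lists are crn50.history['acc'] and crn50.history['val_acc']).
history = {
    "acc": [0.0817, 0.1348, 0.1146, 0.1340, 0.1653, 0.1698, 0.1797, 0.1666, 0.1901, 0.2048],
    "val_acc": [0.1228, 0.1236, 0.1066, 0.1384, 0.1718, 0.1608, 0.1184, 0.1898, 0.1966, 0.1929],
}


def accuracy_gaps(hist):
    """Training minus validation accuracy for each epoch (negative: validation is higher)."""
    return [round(a - v, 4) for a, v in zip(hist["acc"], hist["val_acc"])]


def best_epoch(hist, key="val_acc"):
    """1-based epoch with the highest value of `key`."""
    values = hist[key]
    return values.index(max(values)) + 1


gaps = accuracy_gaps(history)
print(gaps[-1], best_epoch(history))
```

A small gap means the model is not overfitting; it says nothing about how high the accuracy itself is.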

Confusion matrix

Now let's look at the confusion matrix and the Precision, Recall and F1-score metrics.

We make predictions on the validation set and, from them, build the confusion matrix and compute the metrics mentioned above.

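import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sn
from sklearn.metrics import classification_report, confusion_matrix

# Predicted probabilities -> predicted class index for each validation image
crn50_pred = custom_resnet50_model.predict(x_test, batch_size=32, verbose=1)
crn50_predicted = np.argmax(crn50_pred, axis=1)

# y_test is one-hot encoded, so argmax recovers the true class index
crn50_cm = confusion_matrix(np.argmax(y_test, axis=1), crn50_predicted)

# Visualize the confusion matrix
crn50_df_cm = pd.DataFrame(crn50_cm, range(100), range(100))
plt.figure(figsize=(20, 14))
sn.set(font_scale=1.4)  # for label size
sn.heatmap(crn50_df_cm, annot=True, annot_kws={"size": 12})  # font size
plt.show()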
Confusion matrix
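What confusion_matrix computes is simple: entry (i, j) counts the images whose true class is i and whose predicted class is j, so the diagonal holds the correct predictions. A minimal NumPy equivalent, with a hypothetical helper name and toy labels instead of the CIFAR-100 predictions:

```python
import numpy as np


def confusion_matrix_np(y_true, y_pred, n_classes):
    """cm[i, j] = number of samples of true class i predicted as class j."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)   # unbuffered, so repeated (i, j) pairs all count
    return cm


# Toy example with 3 classes; in the article these would be
# np.argmax(y_test, axis=1) and crn50_predicted.
y_true = np.array([0, 0, 1, 1, 2, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0, 2])
cm = confusion_matrix_np(y_true, y_pred, 3)
print(cm)
```

Each row sums to the number of images of that class, which is why every row of the 100x100 CIFAR-100 matrix sums to 100.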

Finally, we display the metrics:

crn50_report = classification_report(np.argmax(y_test, axis=1), crn50_predicted)
print(crn50_report)

             precision    recall  f1-score   support

          0       0.46      0.32      0.38       100
          1       0.25      0.17      0.20       100
          2       0.17      0.09      0.12       100
          3       0.05      0.62      0.09       100
          4       0.18      0.06      0.09       100
          5       0.25      0.05      0.08       100
          6       0.11      0.14      0.12       100
          7       0.15      0.12      0.13       100
          8       0.21      0.20      0.20       100
          9       0.49      0.21      0.29       100
         10       0.11      0.03      0.05       100
         11       0.08      0.05      0.06       100
         12       0.38      0.13      0.19       100
         13       0.23      0.10      0.14       100
         14       0.18      0.05      0.08       100
         15       0.14      0.06      0.08       100
         16       0.19      0.24      0.21       100
         17       0.40      0.19      0.26       100
         18       0.19      0.24      0.21       100
         19       0.20      0.22      0.21       100
         20       0.42      0.31      0.36       100
         21       0.31      0.23      0.26       100
         22       0.35      0.09      0.14       100
         23       0.36      0.37      0.37       100
         24       0.31      0.49      0.38       100
         25       0.17      0.03      0.05       100
         26       0.43      0.06      0.11       100
         27       0.11      0.03      0.05       100
         28       0.31      0.35      0.33       100
         29       0.12      0.10      0.11       100
         30       0.27      0.33      0.30       100
         31       0.11      0.09      0.10       100
         32       0.22      0.20      0.21       100
         33       0.23      0.30      0.26       100
         34       0.17      0.05      0.08       100
         35       0.09      0.02      0.03       100
         36       0.10      0.23      0.14       100
         37       0.15      0.16      0.16       100
         38       0.08      0.24      0.12       100
         39       0.23      0.18      0.20       100
         40       0.26      0.20      0.22       100
         41       0.45      0.49      0.47       100
         42       0.12      0.17      0.14       100
         43       0.11      0.02      0.03       100
         44       0.14      0.09      0.11       100
         45       0.08      0.01      0.02       100
         46       0.07      0.29      0.12       100
         47       0.55      0.18      0.27       100
         48       0.23      0.31      0.26       100
         49       0.27      0.23      0.25       100
         50       0.12      0.05      0.07       100
         51       0.28      0.09      0.14       100
         52       0.47      0.62      0.54       100
         53       0.25      0.13      0.17       100
         54       0.18      0.25      0.21       100
         55       0.00      0.00      0.00       100
         56       0.27      0.27      0.27       100
         57       0.27      0.11      0.16       100
         58       0.15      0.41      0.22       100
         59       0.18      0.10      0.13       100
         60       0.41      0.63      0.50       100
         61       0.33      0.32      0.32       100
         62       0.15      0.07      0.09       100
         63       0.31      0.26      0.28       100
         64       0.11      0.11      0.11       100
         65       0.15      0.11      0.13       100
         66       0.10      0.06      0.08       100
         67       0.15      0.15      0.15       100
         68       0.37      0.66      0.47       100
         69       0.38      0.25      0.30       100
         70       0.21      0.04      0.07       100
         71       0.27      0.54      0.36       100
         72       0.20      0.01      0.02       100
         73       0.30      0.21      0.25       100
         74       0.14      0.15      0.14       100
         75       0.30      0.29      0.29       100
         76       0.40      0.40      0.40       100
         77       0.13      0.14      0.13       100
         78       0.15      0.08      0.10       100
         79       0.14      0.05      0.07       100
         80       0.08      0.05      0.06       100
         81       0.14      0.11      0.12       100
         82       0.37      0.24      0.29       100
         83       0.08      0.02      0.03       100
         84       0.10      0.11      0.10       100
         85       0.23      0.39      0.29       100
         86       0.36      0.21      0.26       100
         87       0.21      0.19      0.20       100
         88       0.05      0.06      0.05       100
         89       0.24      0.18      0.20       100
         90       0.21      0.24      0.22       100
         91       0.33      0.31      0.32       100
         92       0.11      0.11      0.11       100
         93       0.16      0.10      0.12       100
         94       0.38      0.26      0.31       100
         95       0.21      0.50      0.30       100
         96       0.22      0.23      0.22       100
         97       0.10      0.18      0.13       100
         98       0.12      0.02      0.03       100
         99       0.24      0.08      0.12       100

avg / total       0.22      0.19      0.19     10000  
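Every row of this report can be derived from the confusion matrix: precision for class j is the diagonal entry divided by the column sum (everything predicted as j), recall is the diagonal entry divided by the row sum (everything that really is j), and F1 is their harmonic mean. The "avg / total" row printed by this (older) scikit-learn version is the support-weighted average; since every CIFAR-100 class has 100 test images it matches the plain average here, and the weighted recall always equals the overall accuracy (0.19 vs. the 0.1929 validation accuracy). A NumPy sketch on a toy 3-class matrix, with a hypothetical helper name:

```python
import numpy as np


def per_class_report(cm):
    """Precision, recall, F1 and support per class from a confusion matrix."""
    tp = np.diag(cm).astype(float)
    predicted = cm.sum(axis=0)           # column sums: samples predicted as class j
    support = cm.sum(axis=1)             # row sums: samples of true class i
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, support, out=np.zeros_like(tp), where=support > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    return precision, recall, f1, support


cm = np.array([[1, 1, 0],
               [0, 2, 0],
               [1, 0, 2]])
precision, recall, f1, support = per_class_report(cm)
# "avg / total" row: averages weighted by support
avg = [np.average(m, weights=support) for m in (precision, recall, f1)]
print(precision, recall, f1, avg)
```

Classes with zero predictions (such as class 55 in the report) get a precision of 0, which is also how scikit-learn reports them, with a warning.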

ROC curve (true positive and false positive rates)

Next, we compute the ROC curve of each class (one-vs-rest), together with their micro- and macro-averages, and plot them.
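It may help to see first what roc_curve and auc compute for a single class. The scores are sorted from highest to lowest; for each distinct score used as a threshold, the true positive rate (TPR) is the fraction of positives scored at or above it and the false positive rate (FPR) the fraction of negatives; auc then integrates TPR over FPR with the trapezoidal rule. The micro-average applies the same procedure to all 100 one-vs-rest columns flattened together (the .ravel() calls in the code that follows), while the macro-average averages the 100 per-class curves. This is a simplified NumPy sketch with hypothetical function names and toy data; unlike scikit-learn it keeps every threshold (no drop_intermediate).

```python
import numpy as np


def roc_points(y_true, scores):
    """FPR and TPR at every distinct score threshold (assumes both classes are present)."""
    order = np.argsort(-scores, kind="mergesort")
    y, s = y_true[order], scores[order]
    # Last index of each run of equal scores: one threshold per distinct score.
    idx = np.r_[np.flatnonzero(np.diff(s)), y.size - 1]
    tps = np.cumsum(y)[idx]              # positives scored >= threshold
    fps = (idx + 1) - tps                # negatives scored >= threshold
    tpr = np.r_[0.0, tps / tps[-1]]      # the curve starts at (0, 0)
    fpr = np.r_[0.0, fps / fps[-1]]
    return fpr, tpr


def trapezoid_auc(fpr, tpr):
    """Area under the curve with the trapezoidal rule."""
    return float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2.0))


y_true = np.array([0, 0, 1, 1])
scores = np.array([0.1, 0.4, 0.35, 0.8])
fpr, tpr = roc_points(y_true, scores)
print(fpr, tpr, trapezoid_auc(fpr, tpr))
```

On this toy data the sketch gives an AUC of 0.75. The complete multi-class version using scikit-learn is the following: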

import numpy as np
import matplotlib.pyplot as plt
from itertools import cycle
from sklearn.metrics import roc_curve, auc

n_classes = 100

# Plot linewidth.
lw = 2

# Compute ROC curve and ROC area for each class
# (y_test is one-hot encoded; crn50_pred holds the predicted class probabilities)
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i], crn50_pred[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Compute micro-average ROC curve and ROC area
fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), crn50_pred.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

# Compute macro-average ROC curve and ROC area

# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))

# Then interpolate all ROC curves at these points
# (np.interp replaces scipy.interp, a deprecated alias of np.interp)
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
    mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])

# Finally average it and compute AUC
mean_tpr /= n_classes

fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

# Plot all ROC curves
plt.figure()
plt.plot(fpr["micro"], tpr["micro"],
         label='micro-average ROC curve (area = {0:0.2f})'.format(roc_auc["micro"]),
         color='deeppink', linestyle=':', linewidth=4)

plt.plot(fpr["macro"], tpr["macro"],
         label='macro-average ROC curve (area = {0:0.2f})'.format(roc_auc["macro"]),
         color='navy', linestyle=':', linewidth=4)

# Draw only the first 3 classes individually to keep the plot readable
colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
for i, color in zip(range(3), colors):
    plt.plot(fpr[i], tpr[i], color=color, lw=lw,
             label='ROC curve of class {0} (area = {1:0.2f})'.format(i, roc_auc[i]))

plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Some extension of Receiver operating characteristic to multi-class')
plt.legend(loc="lower right")
plt.show()

# Zoom in view of the upper left corner (new figure).
plt.figure()
plt.xlim(0, 0.2)
plt.ylim(0.8, 1)
plt.plot(fpr["micro"], tpr["micro"],
         label='micro-average ROC curve (area = {0:0.2f})'.format(roc_auc["micro"]),
         color='deeppink', linestyle=':', linewidth=4)

plt.plot(fpr["macro"], tpr["macro"],
         label='macro-average ROC curve (area = {0:0.2f})'.format(roc_auc["macro"]),
         color='navy', linestyle=':', linewidth=4)

# First 10 classes (the 3-color cycle repeats)
colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
for i, color in zip(range(10), colors):
    plt.plot(fpr[i], tpr[i], color=color, lw=lw,
             label='ROC curve of class {0} (area = {1:0.2f})'.format(i, roc_auc[i]))

plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Some extension of Receiver operating characteristic to multi-class')
plt.legend(loc="lower right")
plt.show()
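The code above builds two different averages. The micro-average flattens all one-vs-rest problems into a single binary problem, so every (image, class) pair counts as one decision. The macro-average gives each class the same weight. The small sketch below makes the difference concrete on invented data (4 samples, 3 classes) and checks the micro-average against `sklearn.metrics.roc_auc_score`. It also prints the plain mean of the per-class AUCs, which is what `roc_auc_score(..., average="macro")` returns, next to the AUC of the interpolated mean curve used in the article's code.

```python
import numpy as np
from sklearn.metrics import auc, roc_auc_score, roc_curve

# Invented toy data: one-hot labels and predicted class probabilities.
y_true = np.array([[1, 0, 0],
                   [0, 1, 0],
                   [0, 0, 1],
                   [1, 0, 0]])
y_score = np.array([[0.7, 0.2, 0.1],
                    [0.3, 0.4, 0.3],
                    [0.2, 0.2, 0.6],
                    [0.4, 0.5, 0.1]])
n_classes = y_true.shape[1]

# Per-class ROC: each column is an independent binary (one-vs-rest) problem.
fpr, tpr, roc_auc = {}, {}, {}
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_true[:, i], y_score[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Micro-average: pool every (sample, class) decision into one binary problem.
fpr_micro, tpr_micro, _ = roc_curve(y_true.ravel(), y_score.ravel())
micro_auc = auc(fpr_micro, tpr_micro)

# Macro-average (curve version, as in the article): interpolate every
# per-class curve on a common FPR grid, then average the TPRs.
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
mean_tpr = np.mean([np.interp(all_fpr, fpr[i], tpr[i]) for i in range(n_classes)], axis=0)
macro_curve_auc = auc(all_fpr, mean_tpr)

# Macro-average (score version): plain mean of the per-class AUCs.
macro_mean_auc = np.mean([roc_auc[i] for i in range(n_classes)])

print(f"micro AUC: {micro_auc:.3f} "
      f"(roc_auc_score: {roc_auc_score(y_true, y_score, average='micro'):.3f})")
print(f"macro AUC, mean of per-class AUCs: {macro_mean_auc:.3f} "
      f"(roc_auc_score: {roc_auc_score(y_true, y_score, average='macro'):.3f})")
print(f"macro AUC, interpolated mean curve: {macro_curve_auc:.3f}")
```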

The results are shown in the following plots: the ROC curves for three classes, and a zoomed-in view of the upper-left corner for ten classes.

ROC curve for 3 classes
Zoomed ROC curve for 10 classes

We will save the training history so we can compare it with other models. We will also save the model with its trained weights for future use.

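import pickle

# Save the model with its trained weights. `model` is an assumed name for the
# Keras model built with CustomResNet50 (the original line is garbled).
model.save(path_base + '/crn50.h5')

# Save the training history (dict of per-epoch metrics) with pickle
with open(path_base + '/crn50_history.txt', 'wb') as file_pi:
    pickle.dump(crn50.history, file_pi)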
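To reuse a pickled history later, for instance when comparing models, it can be read back with `pickle.load`. Here is a minimal round-trip sketch: `save_history` and `load_history` are hypothetical helpers, and the metric values are invented, shaped like the dict stored in Keras' `History.history` (metric name mapped to a list of per-epoch values).

```python
import os
import pickle
import tempfile


def save_history(history_dict, path):
    """Serialize a Keras-style history dict to `path` (binary pickle)."""
    with open(path, "wb") as f:
        pickle.dump(history_dict, f)


def load_history(path):
    """Read back a history dict written by save_history."""
    with open(path, "rb") as f:
        return pickle.load(f)


# Invented values standing in for crn50.history.
fake_history = {"loss": [4.1, 3.6, 3.2], "val_loss": [4.0, 3.7, 3.5]}
path = os.path.join(tempfile.gettempdir(), "crn50_history_demo.txt")

save_history(fake_history, path)
restored = load_history(path)
print(restored["val_loss"])
```

The `.txt` extension only mirrors the article's file name; the content is binary pickle data, so it must be opened with `"rb"`.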

Next, we compare the metrics with those of the previous models (we skip the code that loads those models' data).

plt.xticks(np.arange(0, 11, 2.0))  
plt.rcParams['figure.figsize'] = (8, 6)  
plt.xlabel("Num of Epochs")  
plt.title("Simple NN Accuracy vs simple CNN Accuracy")  
plt.legend(['simple NN','CNN','VGG 16','VGG 19','Custom VGG','Custom ResNet'])  
Simple NN Vs CNN accuracy
plt.xticks(np.arange(0, 11, 2.0))  
plt.rcParams['figure.figsize'] = (8, 6)  
plt.xlabel("Num of Epochs")  
plt.title("Simple NN Loss vs simple CNN Loss")  
plt.legend(['simple NN','CNN','VGG 16','VGG 19','Custom VGG','Custom ResNet'])  
Simple NN Vs CNN loss
plt.xticks(np.arange(0, 11, 2.0))  
plt.rcParams['figure.figsize'] = (8, 6)  
plt.xlabel("Num of Epochs")  
plt.ylabel("Mean Squared Error")  
plt.title("Simple NN MSE vs simple CNN MSE")  
plt.legend(['simple NN','CNN','VGG 16','VGG 19','Custom VGG','Custom ResNet'])  
Simple NN Vs CNN MSE
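The per-model plotting calls are not shown in the snippets above. The sketch below assumes each model's history was saved as a dict, in the same way as `crn50_history.txt`. A hypothetical helper, `plot_metric`, draws one curve per model on a single figure. It sets the figure size when the figure is created, since `plt.rcParams['figure.figsize']` only affects figures created after it is changed. The histories used here are made up.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend for this sketch; drop this line in a notebook
import matplotlib.pyplot as plt


def plot_metric(histories, metric, ylabel, title):
    """Draw one curve per model for `metric` (e.g. "loss") on one figure.

    `histories` maps a legend label to a Keras-style history dict
    (metric name -> list of per-epoch values).
    """
    fig, ax = plt.subplots(figsize=(8, 6))  # size fixed before drawing
    for label, history in histories.items():
        values = history[metric]
        ax.plot(range(len(values)), values, label=label)
    ax.set_xlabel("Num of Epochs")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    return fig, ax


# Made-up histories standing in for the unpickled ones.
histories = {
    "Custom VGG": {"loss": [4.2, 3.9, 3.7]},
    "Custom ResNet": {"loss": [4.0, 3.5, 3.1]},
}
fig, ax = plot_metric(histories, "loss", "Loss", "Training loss comparison")
```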

Conclusions from the experiment

As can be seen, this architecture marks a turning point. It obtains better results than the previous architectures, and it also improves training time, because it allows more layers to be added while keeping training time acceptable. In addition, the number of parameters is considerably lower than in the VGG architecture.
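One ingredient of the smaller parameter count can be checked with simple arithmetic. The Keras `resnet50.conv_block` and `resnet50.identity_block` used in `CustomResNet50` are bottleneck blocks (1×1, 3×3, 1×1 convolutions). A Conv2D layer has k·k·C_in·C_out weights plus one bias per filter. The sketch below compares an identity block with filters [64, 64, 256] on a 256-channel input against two plain VGG-style 3×3 convolutions with 256 filters. `conv2d_params` is a hypothetical helper, and batch-normalization parameters and the classifier head are ignored, so this is only part of the full comparison with VGG.

```python
def conv2d_params(kernel_size, in_channels, out_channels, use_bias=True):
    """Trainable parameters of a square Conv2D layer: k*k*C_in*C_out weights + biases."""
    weights = kernel_size * kernel_size * in_channels * out_channels
    return weights + (out_channels if use_bias else 0)


# Bottleneck identity block, filters [64, 64, 256], on a 256-channel input.
bottleneck = (conv2d_params(1, 256, 64)     # 1x1 reduces 256 -> 64 channels
              + conv2d_params(3, 64, 64)    # 3x3 works on the reduced width
              + conv2d_params(1, 64, 256))  # 1x1 restores 64 -> 256 channels

# Two plain 3x3 convolutions that keep 256 channels throughout.
plain = 2 * conv2d_params(3, 256, 256)

print(bottleneck, plain)  # 70016 1180160
```

With the same input and output width, the bottleneck block needs roughly 17 times fewer convolution parameters.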

In the next article we will present the DenseNet architecture. Follow us on Twitter to keep up with the upcoming posts in this series and much more!


Jesús Utrera Burgal

.NET developer for more than 10 years. In recent years I have moved into the world of Machine Learning, specifically the area of Supervised Deep Learning.