更新時(shí)間:2023-07-21 來源:黑馬程序員 瀏覽量:
ResNet(Residual Network)是由Kaiming He等人提出的深度學(xué)習(xí)神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu),它在2015年的ImageNet圖像識(shí)別競(jìng)賽中取得了非常顯著的成績(jī),引起了廣泛的關(guān)注。ResNet的主要貢獻(xiàn)是解決了深度神經(jīng)網(wǎng)絡(luò)的梯度消失問題,使得可以訓(xùn)練更深的網(wǎng)絡(luò),從而獲得更好的性能。
問題:在傳統(tǒng)的深度神經(jīng)網(wǎng)絡(luò)中,隨著網(wǎng)絡(luò)層數(shù)的增加,梯度在反向傳播過程中逐漸變小,導(dǎo)致淺層網(wǎng)絡(luò)的權(quán)重更新幾乎沒有效果,難以訓(xùn)練。這被稱為梯度消失問題。
ResNet的解決方法:ResNet引入了“殘差塊”(residual block),每個(gè)殘差塊包含了一條“跳躍連接”(shortcut connection),它允許梯度能夠直接穿過塊,從而避免了梯度消失問題。因此,深度網(wǎng)絡(luò)可以通過恒等映射(identity mapping)來學(xué)習(xí)殘差,使得網(wǎng)絡(luò)在增加深度時(shí)反而變得更容易訓(xùn)練。
ResNet結(jié)構(gòu)特點(diǎn):
1.殘差塊:每個(gè)殘差塊由兩個(gè)或三個(gè)卷積層組成,它們的輸出通過跳躍連接與塊的輸入相加,形成殘差(residual)。
2.跳躍連接:跳躍連接允許梯度直接流過塊,有助于避免梯度消失問題。
3.批量歸一化:ResNet中廣泛使用批量歸一化層來加速訓(xùn)練并穩(wěn)定網(wǎng)絡(luò)。
4.殘差塊堆疊:ResNet通過堆疊多個(gè)殘差塊來構(gòu)建深層網(wǎng)絡(luò)。深度可以根據(jù)任務(wù)的復(fù)雜性而自由選擇。
接下來我們看一個(gè)簡(jiǎn)化的ResNet代碼演示(使用TensorFlow):
import tensorflow as tf
from tensorflow.keras import layers, models
# 定義一個(gè)基本的殘差塊
def residual_block(x, filters, downsample=False):
# 如果downsample為True,使用步長(zhǎng)為2的卷積層實(shí)現(xiàn)降采樣
stride = 2 if downsample else 1
# 記錄輸入,以便在跳躍連接時(shí)使用
identity = x
# 第一個(gè)卷積層
x = layers.Conv2D(filters, kernel_size=3, strides=stride, padding='same')(x)
x = layers.BatchNormalization()(x)
x = layers.Activation('relu')(x)
# 第二個(gè)卷積層
x = layers.Conv2D(filters, kernel_size=3, strides=1, padding='same')(x)
x = layers.BatchNormalization()(x)
# 如果進(jìn)行了降采樣,需要對(duì)identity進(jìn)行相應(yīng)處理,保證維度一致
if downsample:
identity = layers.Conv2D(filters, kernel_size=1, strides=stride, padding='same')(identity)
identity = layers.BatchNormalization()(identity)
# 跳躍連接:將卷積層的輸出與輸入相加
x = layers.add([x, identity])
x = layers.Activation('relu')(x)
return x
# 構(gòu)建ResNet網(wǎng)絡(luò)
def ResNet(input_shape, num_classes):
input_img = layers.Input(shape=input_shape)
# 第一個(gè)卷積層
x = layers.Conv2D(64, kernel_size=7, strides=2, padding='same')(input_img)
x = layers.BatchNormalization()(x)
x = layers.Activation('relu')(x)
x = layers.MaxPooling2D(pool_size=3, strides=2, padding='same')(x)
# 堆疊殘差塊組成網(wǎng)絡(luò)
x = residual_block(x, filters=64)
x = residual_block(x, filters=64)
x = residual_block(x, filters=64)
x = residual_block(x, filters=128, downsample=True)
x = residual_block(x, filters=128)
x = residual_block(x, filters=128)
x = residual_block(x, filters=256, downsample=True)
x = residual_block(x, filters=256)
x = residual_block(x, filters=256)
x = residual_block(x, filters=512, downsample=True)
x = residual_block(x, filters=512)
x = residual_block(x, filters=512)
# 全局平均池化
x = layers.GlobalAveragePooling2D()(x)
# 全連接層輸出
x = layers.Dense(num_classes, activation='softmax')(x)
# 創(chuàng)建模型
model = models.Model(inputs=input_img, outputs=x)
return model
# 在這里定義輸入圖像的形狀和類別數(shù)
input_shape = (224, 224, 3)
num_classes = 1000
# 構(gòu)建ResNet模型
model = ResNet(input_shape, num_classes)
model.summary()
請(qǐng)注意,上述代碼是一個(gè)簡(jiǎn)化版本的ResNet網(wǎng)絡(luò),實(shí)際上,ResNet有不同的變體,可以根據(jù)任務(wù)的復(fù)雜性和資源的可用性選擇適合的ResNet結(jié)構(gòu)。