检测头篇 | YOLOv8 更换 SEResNeXtBottleneck 头 | 附详细结构图
from torch import nn
class SEResNeXtBottleneck(nn.Module):
"""
定义一个 SEResNeXt 型 bottleneck 模块。
"""
def __init__(self, in_channels, out_channels, stride, cardinality, bottleneck_width, is_first):
super(SEResNeXtBottleneck, self).__init__()
mid_channels = cardinality * bottleneck_width
self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(mid_channels)
self.conv2 = nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False)
self.bn2 = nn.BatchNorm2d(mid_channels)
self.conv3 = nn.Conv2d(mid_channels, out_channels, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(out_channels)
self.shortcut = nn.Sequential()
if not is_first:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(out_channels)
)
def forward(self, x):
out = self.conv1(x)
out = self.bn1(out)
out = nn.functional.relu(out, inplace=True)
out = self.conv2(out)
out = self.bn2(out)
out = nn.functional.relu(out, inplace=True)
out = self.conv3(out)
out = self.bn3(out)
identity = self.shortcut(x)
out += identity
out = nn.functional.relu(out, inplace=True)
return out
# 使用示例
cardinality = 32
bottleneck_width = 4
stride = 2
in_channels = 64
out_channels = 256
is_first = False
bottleneck = SEResNeXtBottleneck(in_channels, out_channels, stride, cardinality, bottleneck_width, is_first)
input_tensor = torch.randn(1, in_channels, 56, 56)
output_tensor = bottleneck(input_tensor)
print(output_tensor.size())
这段代码定义了一个 SEResNeXtBottleneck 类,它是用于构建深度学习中 ResNeXt 架构的 SEResNeXt 型 bottleneck 模块。它接收输入特征图并通过一系列的卷积和激活函数处理,然后将处理后的特征图和直接从输入特征图 short-cut 的结果相加,最后再进行一次激活函数处理,以提高网络的学习能力和鲁棒性。使用示例展示了如何实例化这个模块并进行前向传播。
评论已关闭