图像分割代码分析
UNetPlusPlus 图像分割代码分析
训练代码与解释
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from my_dataset import ImageSegmentationDataset # 自定义数据集
from NestedUNet import NestedUNet # 模型定义文件
# 定义超参数
batch_size = 1
learning_rate = 1e-4
num_epochs = 200
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
# 计算新尺寸,将原始尺寸除以 2
new_height = 2048 // 2
new_width = 3072 // 2
# 数据预处理和数据增强
transform = transforms.Compose([
transforms.Resize((new_height, new_width)), # 将图像大小调整为原始尺寸的一半
transforms.ToTensor() # 转换为 PyTorch 张量
])
# 加载数据
train_dataset = ImageSegmentationDataset(image_dir='./dataset/train/images',
mask_dir='./dataset/train/masks',
transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# 初始化模型、损失函数、优化器
model = NestedUNet(num_classes=2, input_channels=3).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# 训练循环
for epoch in range(num_epochs):
model.train()
epoch_loss = 0.0
for images, masks in train_loader:
images, masks = images.to(device), masks.to(device)
# 确保目标张量的形状为 [batch_size, height, width]
masks = torch.squeeze(masks, dim=1) # 去除通道维度
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, masks)
loss.backward()
optimizer.step()
print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')
# 保存训练好的模型
torch.save(model.state_dict(), './model.pth')
1. 数据预处理和加载
transform = transforms.Compose([
transforms.Resize((new_height, new_width)), # 将图像大小调整为原始尺寸的一半
transforms.ToTensor() # 转换为 PyTorch 张量
])
- Resize: 将图像和掩码调整为新的尺寸
(new_height, new_width)
,这里是对原始尺寸(2048, 3072)
进行缩小。 - ToTensor: 将图像和掩码转换为 PyTorch 张量,并将像素值归一化到[0, 1]范围。
train_dataset = ImageSegmentationDataset(image_dir='./dataset/train/images',
mask_dir='./dataset/train/masks',
transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
- ImageSegmentationDataset: 自定义的数据集类,负责加载图像和对应的掩码。
- DataLoader: 将数据集包装成可迭代的 DataLoader,设置 batch size 和 shuffle。
2. 模型、损失函数和优化器的初始化
model = NestedUNet(num_classes=2, input_channels=3).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
- NestedUNet: 自定义的神经网络模型,用于图像分割,输入通道数为 3(RGB 图像),输出类别数为 2。
- CrossEntropyLoss: 适用于多类分类任务的损失函数,常用于图像分割。
- Adam 优化器: 用于更新网络参数。
3. 训练循环
for epoch in range(num_epochs):
model.train()
epoch_loss = 0.0
for images, masks in train_loader:
images, masks = images.to(device), masks.to(device)
# 确保目标张量的形状为 [batch_size, height, width]
masks = torch.squeeze(masks, dim=1) # 去除通道维度
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, masks)
loss.backward()
optimizer.step()
print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')
- model.train(): 将模型设置为训练模式,启用 dropout 和 batch normalization。
- images, masks = images.to(device), masks.to(device): 将数据转移到 GPU 或 CPU。
- masks = torch.squeeze(masks, dim=1): 这是关键一步,解释如下。
4. 关于通道处理的详细解释
在图像分割任务中:
- 输入图像通常是三维的,形状为
[batch_size, channels, height, width]
,例如[1, 3, 1024, 1536]
。 - 掩码(mask)通常是四维的,但通道数为 1,形状为
[batch_size, 1, height, width]
,例如[1, 1, 1024, 1536]
。
然而,nn.CrossEntropyLoss
函数要求目标掩码的形状为 [batch_size, height, width]
,即不包含通道维度。
因此,需要使用 torch.squeeze
函数去除掩码的通道维度:
masks = torch.squeeze(masks, dim=1)
这将掩码的形状从 [batch_size, 1, height, width]
变为 [batch_size, height, width]
,满足损失函数的要求。
5. 模型输出与损失计算
- outputs = model(images): 模型输出形状为
[batch_size, num_classes, height, width]
,例如[1, 2, 1024, 1536]
。 - loss = criterion(outputs, masks): 计算预测结果与真实掩码之间的交叉熵损失。
6. 模型保存
torch.save(model.state_dict(), './model.pth')
- 保存模型的参数到文件
model.pth
,方便后续加载和推理。
这段代码的主要功能是加载一个预训练的 NestedUNet 模型,使用它对指定目录下的图像进行分割,并将结果保存到输出目录。代码的执行流程如下:
推理代码与解释
import os
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from NestedUNet import NestedUNet # 模型定义文件
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# 加载模型
def load_model(model, path):
if not os.path.exists(path):
raise FileNotFoundError(f"Model file not found: {path}")
model.load_state_dict(torch.load(path, map_location=device))
model.eval()
return model
# 进行推理
def segment_images(model, image_dir, output_dir):
# 计算新尺寸,将原始尺寸除以 2
new_height = 2048 // 2
new_width = 3072 // 2
# 数据预处理和数据增强
transform = transforms.Compose([
transforms.Resize((new_height, new_width)), # 将图像大小调整为原始尺寸的一半
transforms.ToTensor() # 转换为 PyTorch 张量
])
os.makedirs(output_dir, exist_ok=True)
for filename in os.listdir(image_dir):
if filename.endswith(('.png', '.jpg', '.jpeg')):
filepath = os.path.join(image_dir, filename)
image = Image.open(filepath).convert('RGB')
input_tensor = transform(image).unsqueeze(0).to(device) # 增加批量维度
with torch.no_grad():
outputs = model(input_tensor)
prediction = torch.argmax(outputs, dim=1).squeeze(0) # 获取分割结果
# 保存分割结果
output_filename = filename.split('.')[0] + '_segmentation.png'
output_path = os.path.join(output_dir, output_filename)
# 将类别值映射到 0-255 范围
pred_img = prediction.cpu().numpy().astype(np.uint8) * 255
Image.fromarray(pred_img).save(output_path)
# 主执行代码
if __name__ == "__main__":
model = NestedUNet(num_classes=2, input_channels=3).to(device)
model = load_model(model, './model.pth') # 加载预训练模型
# 定义输入目录和输出目录
input_dirs = [
'./dataset/1-2000',
'./dataset/2001-4000',
'./dataset/4001-6000',
'./dataset/6001-8000',
'./dataset/8001-9663'
]
base_output_dir = './dataset/segmentation_results' # 基础输出结果目录
for input_dir in input_dirs:
output_dir = os.path.join(base_output_dir, os.path.basename(input_dir))
segment_images(model, input_dir, output_dir)
print(f"Segmentation results saved to: {base_output_dir}")
1. 设备选择
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
- 根据机器是否有可用的 GPU(通过
torch.cuda.is_available()
检查)来选择计算设备。如果有 GPU 可用,代码会选择 GPU,否则使用 CPU。
2. 加载模型
def load_model(model, path):
if not os.path.exists(path):
raise FileNotFoundError(f"Model file not found: {path}")
model.load_state_dict(torch.load(path, map_location=device))
model.eval()
return model
load_model
函数:- 检查路径中是否有模型文件。
- 使用
torch.load()
加载预训练的模型参数。 - 加载后,调用
model.eval()
将模型设置为评估模式(禁用 dropout 等操作)。 - 该函数返回加载完权重的模型。
3. 进行图像分割推理
def segment_images(model, image_dir, output_dir):
new_height = 2048 // 2
new_width = 3072 // 2
- 设置目标图像的大小,即将原始图像的高和宽各缩小一半(
2048 // 2
和3072 // 2
)。
数据预处理
transform = transforms.Compose([
transforms.Resize((new_height, new_width)), # 调整图像大小
transforms.ToTensor() # 转换为 PyTorch 张量
])
- 图像通过
Resize
变换调整为新尺寸。 - 然后通过
ToTensor()
转换为 PyTorch 张量格式,方便输入到模型。
处理每张图像
for filename in os.listdir(image_dir):
if filename.endswith(('.png', '.jpg', '.jpeg')):
filepath = os.path.join(image_dir, filename)
image = Image.open(filepath).convert('RGB')
input_tensor = transform(image).unsqueeze(0).to(device) # 增加批量维度
- 遍历
image_dir
目录中的所有图像文件(支持.png
、.jpg
和.jpeg
格式)。 - 对每张图像进行读取并转换为 RGB 模式(即使是灰度图也会被处理为 RGB 图像)。
- 使用预处理
transform
转换为张量,并添加一个批量维度(unsqueeze(0)
),使形状变为[1, C, H, W]
(适配模型输入)。
模型推理
with torch.no_grad():
outputs = model(input_tensor)
prediction = torch.argmax(outputs, dim=1).squeeze(0) # 获取分割结果
- 使用
torch.no_grad()
禁用梯度计算,节省内存和加速推理。 model(input_tensor)
会返回模型的输出(每个像素的类别概率)。torch.argmax(outputs, dim=1)
:对每个像素,取类别概率最大的一项作为预测类别。squeeze(0)
:去除批量维度,得到的prediction
形状为[H, W]
。
保存分割结果
output_filename = filename.split('.')[0] + '_segmentation.png'
output_path = os.path.join(output_dir, output_filename)
pred_img = prediction.cpu().numpy().astype(np.uint8) * 255
Image.fromarray(pred_img).save(output_path)
output_filename
:为每个输出图像文件命名,格式为原始文件名加_segmentation.png
。prediction.cpu().numpy()
:将预测结果从 GPU 移到 CPU,并转换为 NumPy 数组。astype(np.uint8) * 255
:将预测类别(0 或 1)映射到灰度值(0 或 255),使得结果可以保存为黑白图片。- 使用 Pillow 将
pred_img
保存为 PNG 格式。
4. 主执行代码
if __name__ == "__main__":
model = NestedUNet(num_classes=2, input_channels=3).to(device)
model = load_model(model, './model.pth') # 加载预训练模型
input_dirs = [
'./dataset/1-2000',
'./dataset/2001-4000',
'./dataset/4001-6000',
'./dataset/6001-8000',
'./dataset/8001-9663'
]
base_output_dir = './dataset/segmentation_results' # 基础输出结果目录
for input_dir in input_dirs:
output_dir = os.path.join(base_output_dir, os.path.basename(input_dir))
segment_images(model, input_dir, output_dir)
print(f"Segmentation results saved to: {base_output_dir}")
- 在主程序中,首先加载
NestedUNet
模型并加载权重。 - 定义一个包含多个子文件夹路径(
input_dirs
)的列表,每个文件夹包含待分割的图像。 - 为每个输入文件夹生成对应的输出文件夹,将分割结果保存到这些输出文件夹。
- 最后输出保存结果的目录路径。
数据预处理代码与解释
import os
import numpy as np
import torch
from PIL import Image
class ImageSegmentationDataset:
def __init__(self, image_dir, mask_dir, transform=None):
self.image_dir = image_dir
self.mask_dir = mask_dir
self.transform = transform
self.image_files = sorted(os.listdir(image_dir)) # 获取图片文件列表并排序
def __getitem__(self, idx):
# 获取图像文件名
image_file = self.image_files[idx]
image_path = os.path.join(self.image_dir, image_file)
# 构建掩膜文件名,假设掩膜文件以 "_mask" 结尾
mask_file = image_file.replace(".jpg", "_mask.png")
mask_path = os.path.join(self.mask_dir, mask_file)
# 加载图像和掩膜
image = Image.open(image_path).convert('RGB')
mask = Image.open(mask_path).convert('L') # 灰度图
# 如果有 transform(数据增强等),应用 transform
if self.transform:
image = self.transform(image)
mask = self.transform(mask)
mask = torch.tensor(np.array(mask, dtype=np.int64))
return image, mask
def __len__(self):
# 返回数据集中图像文件的数量
return len(self.image_files)
__init__
构造函数
def __init__(self, image_dir, mask_dir, transform=None):
self.image_dir = image_dir
self.mask_dir = mask_dir
self.transform = transform
self.image_files = sorted(os.listdir(image_dir)) # 获取图片文件列表并排序
image_dir
:图像存放的目录路径。mask_dir
:掩膜图像存放的目录路径。每张图像都会有一个对应的掩膜图像,掩膜是标注了目标区域的图像。transform
:如果有数据预处理或数据增强的操作,可以传递给transform
。例如,可能会进行图像大小调整、归一化等。image_files
:获取image_dir
中的所有文件名并进行排序,确保图像的顺序与掩膜的顺序一致。
__getitem__
方法
def __getitem__(self, idx):
# 获取图像文件名
image_file = self.image_files[idx]
image_path = os.path.join(self.image_dir, image_file)
# 构建掩膜文件名,假设掩膜文件以 "_mask" 结尾
mask_file = image_file.replace(".jpg", "_mask.png")
mask_path = os.path.join(self.mask_dir, mask_file)
# 加载图像和掩膜
image = Image.open(image_path).convert('RGB')
mask = Image.open(mask_path).convert('L') # 灰度图
idx
:传入的索引,表示要加载数据集中的哪一张图片及其对应的掩膜。image_file
:根据idx
获取当前图像文件的名称。image_path
:根据文件名构建图像的完整路径。mask_file
:假设掩膜图像与原图像的文件名一致,只是在原文件名的基础上加上_mask
后缀(假设原图为.jpg
,掩膜图像为.png
)。可以根据需要修改这个规则。mask_path
:根据掩膜文件名构建掩膜图像的完整路径。加载图像和掩膜:
- 使用
Pillow
的Image.open()
函数加载图像,并使用.convert('RGB')
确保图像是三通道的 RGB 格式。 - 掩膜是灰度图,所以加载时使用
.convert('L')
,使其成为单通道的灰度图像。
- 使用
应用预处理操作
if self.transform:
image = self.transform(image)
mask = self.transform(mask)
- 如果传递了
transform
(例如数据增强或预处理操作),则对图像和掩膜应用该操作。通常,这里会进行如调整大小、归一化、数据增强等操作。
转换掩膜为 PyTorch 张量
mask = torch.tensor(np.array(mask, dtype=np.int64))
- 将掩膜图像从
Pillow
图像对象转换为 NumPy 数组。 - 然后将 NumPy 数组转换为 PyTorch 张量,类型为
int64
。这里使用int64
是因为通常分割任务的标签是整数类型(比如每个像素对应的类别 ID)。
__len__
方法
def __len__(self):
# 返回数据集中图像文件的数量
return len(self.image_files)
- 该方法返回数据集中图像文件的数量。PyTorch 数据集类需要实现该方法,以便能够知道数据集的大小。
这段代码实现了一个名为 Nested U-Net 的深度学习模型,主要用于 图像分割 任务。Nested U-Net 是在传统 U-Net 基础上改进的一种结构,通过增加嵌套跳跃连接(nested skip connections),进一步提升了模型的分割精度。下面我会详细解释代码中的各个部分,特别是每个模块的作用。
Nested U-Net
VGGBlock
class VGGBlock(nn.Module):
def __init__(self, in_channels, middle_channels, out_channels):
super().__init__()
self.relu = nn.ReLU(inplace=True)
self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
self.bn1 = nn.BatchNorm2d(middle_channels)
self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
self.bn2 = nn.BatchNorm2d(out_channels)
def forward(self, x):
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
return out
VGGBlock 是模型中一个核心的卷积块。每个块包含:
- 卷积层:
conv1
和conv2
,都使用3x3
的卷积核,并且使用了 padding=1 保证输出尺寸和输入相同。 - 批归一化层:
bn1
和bn2
,用于加速训练并稳定模型。 - ReLU 激活函数:增加非线性表达能力。
这个块被多次重复调用,构成 U-Net 和 Nested U-Net 的基础卷积操作。
- 卷积层:
NestedUNet
class NestedUNet(nn.Module):
def __init__(self, num_classes=2, input_channels=2, deep_supervision=False, **kwargs):
super().__init__()
nb_filter = [32, 64, 128, 256, 512]
self.deep_supervision = deep_supervision
self.pool = nn.MaxPool2d(2, 2)
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
# 定义每一层的卷积模块
self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])
# 定义嵌套的卷积模块(即跳跃连接)
self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])
self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0])
self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1])
self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2])
self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0])
self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1])
self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0])
# 最终的输出层,支持深度监督(Deep Supervision)
if self.deep_supervision:
self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
else:
self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
主要模块:
卷积层:模型的每一层都是由
VGGBlock
组成。每一层的输出通道数逐渐增加(32, 64, 128, 256, 512),然后在之后的嵌套层中通过跳跃连接进一步进行融合。跳跃连接:这种设计是 Nested U-Net 的关键,每一层的输出不仅用于下一层,还和其他层的输出拼接(concatenate)。这种设计帮助保留了更多的细节信息,并改善了分割精度。
上采样(Upsampling):通过
Upsample
将图像尺寸增大,并进行跳跃连接后,再进行卷积操作。深度监督(Deep Supervision):通过在多个阶段产生输出,增强模型的学习效果。这是 Nested U-Net 的一个特点,可以让模型在不同的深度层次上进行监督,提高性能。
Forward 方法
def forward(self, input):
# 各种卷积操作
x0_0 = self.conv0_0(input)
x1_0 = self.conv1_0(self.pool(x0_0))
x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))
x2_0 = self.conv2_0(self.pool(x1_0))
x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))
# 继续进行嵌套连接和卷积操作,直到最后一层
x3_0 = self.conv3_0(self.pool(x2_0))
x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))
x4_0 = self.conv4_0(self.pool(x3_0))
x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))
if self.deep_supervision:
output1 = self.final1(x0_1)
output2 = self.final2(x0_2)
output3 = self.final3(x0_3)
output4 = self.final4(x0_4)
return [output1, output2, output3, output4]
else:
output = self.final(x0_4)
return output
- 卷积和池化操作:通过
self.pool
进行下采样(池化),通过self.up
进行上采样(反卷积),并通过拼接(torch.cat
)将不同层的输出合并起来。 - 深度监督输出:如果启用深度监督,会在多个中间层输出结果;否则,只在最后输出结果。