CRNN 卷积循环神经网络
CRNN(Convolutional Recurrent Neural Network,卷积循环神经网络)是一种结合了卷积神经网络(CNN)和循环神经网络(RNN)的混合深度学习模型,主要用于处理序列建模与预测任务,尤其在图像文本识别(如OCR)、语音识别和手写识别等领域表现优异。以下是其核心结构和特点:
1. 模型结构
CRNN通常由三部分组成:
- CNN特征提取层
- 使用类似VGG或ResNet的卷积网络,提取输入图像的局部特征(如边缘、纹理、字符形状等)。
- 例如:输入一张自然场景的文字图像,CNN会将其转换为特征序列(每个特征对应图像中的一列或一块区域)。
- RNN序列建模层
- 通常采用双向LSTM(Bi-LSTM)或GRU,捕捉特征序列中的上下文依赖关系。
- 例如:在文本识别中,RNN能学习字符间的顺序关系(如“hello”中的“h”后接“e”的概率)。
- CTC解码层
- 使用Connectionist Temporal Classification (CTC) 损失函数,解决输入序列与输出标签长度不一致的问题。
- 例如:将RNN输出的变长特征序列映射为最终文本标签(如“hello”),无需预先对齐输入与输出。
示例:OCR中的CRNN工作流程
- 输入:一张包含文字的图像(如“HELLO”)。
- CNN:提取图像的特征序列(每个特征对应一个字符区域)。
- RNN:捕捉字符间的上下文关系(如“H”后接“E”的概率更高)。
- CTC:将RNN输出的概率分布转换为最终文本标签(如“HELLO”),无需字符位置对齐。
CRNN通过融合CNN的空间特征提取能力和RNN的序列建模能力,成为处理序列数据的经典架构,尤其在文本识别领域具有重要地位。
代码示例
以下是一个使用PyTorch实现CRNN(CNN+BiLSTM+CTC)的示例代码,适用于图像文本识别(OCR)任务:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import random
# ---------------------
# 1. 定义CRNN模型
# ---------------------
class CRNN(nn.Module):
    """CNN + BiLSTM + linear head for CTC-based sequence recognition.

    Args:
        num_classes: size of the output alphabet, including the CTC blank
            (index 0 is conventionally the blank).
        hidden_size: number of hidden units per LSTM direction.

    Input:  (batch, 1, H, W) grayscale images (H must be 32 for the CNN
            below to collapse the height dimension to exactly 1).
    Output: (batch, T, num_classes) raw logits, where T is the number of
            time steps derived from the image width.
    """

    def __init__(self, num_classes, hidden_size=256):
        super(CRNN, self).__init__()
        # CNN feature extractor: turns a (1, H, W) image into a width-wise
        # sequence of 512-dim feature vectors (height is reduced to 1).
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),          # (batch, 64, H/2, W/2)
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),          # (batch, 128, H/4, W/4)
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),  # (batch, 256, H/8, W/4)
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),  # (batch, 512, H/16, W/4)
            nn.Conv2d(512, 512, kernel_size=2, stride=1, padding=0),  # (batch, 512, 1, W/4-1)
        )
        # BUG FIX: nn.LSTM must NOT live inside nn.Sequential.  An LSTM's
        # forward returns an (output, (h_n, c_n)) tuple, and Sequential would
        # pass that whole tuple into the next module (Dropout), crashing at
        # runtime.  Keep each RNN layer as its own attribute and unpack the
        # tuples explicitly in forward().
        self.lstm1 = nn.LSTM(512, hidden_size, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(0.5)
        self.lstm2 = nn.LSTM(hidden_size * 2, hidden_size,
                             bidirectional=True, batch_first=True)
        # Fully-connected head: per-time-step character logits.
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        # CNN feature extraction.
        x = self.cnn(x)          # (batch, 512, 1, T)
        x = x.squeeze(2)         # drop the height dim -> (batch, 512, T)
        x = x.permute(0, 2, 1)   # -> (batch, T, 512) for batch_first LSTMs
        # Two stacked bidirectional LSTMs with dropout in between.
        x, _ = self.lstm1(x)     # (batch, T, hidden_size*2)
        x = self.dropout(x)
        x, _ = self.lstm2(x)     # (batch, T, hidden_size*2)
        x = self.fc(x)           # (batch, T, num_classes)
        return x
# ---------------------
# 2. 数据预处理与加载
# ---------------------
class TextDataset(Dataset):
    """Synthetic dataset yielding (image_tensor, label_index_tensor) pairs.

    Labels are encoded with indices starting at 1; index 0 is reserved for
    the CTC blank token.
    """

    def __init__(self, data_size=1000, max_length=10, img_width=128, img_height=32):
        self.data_size = data_size
        self.max_length = max_length
        self.img_width = img_width
        self.img_height = img_height
        self.characters = 'abcdefghijklmnopqrstuvwxyz0123456789'
        # enumerate(..., start=1) keeps index 0 free for the CTC blank.
        self.char2idx = {c: i for i, c in enumerate(self.characters, start=1)}
        self.idx2char = {i: c for i, c in enumerate(self.characters, start=1)}
        self.transform = transforms.Compose([
            transforms.Resize((img_height, img_width)),
            transforms.Grayscale(),
            transforms.ToTensor(),
        ])

    def generate_sample(self):
        # NOTE(review): the canvas stays blank white — the text is never
        # actually drawn onto the image, so this only exercises the training
        # pipeline, not real OCR.  Confirm whether rendering was intended.
        length = random.randint(1, self.max_length)
        text = ''.join(random.choices(self.characters, k=length))
        img = Image.new('L', (self.img_width, self.img_height), color=255)
        return img, text

    def __len__(self):
        return self.data_size

    def __getitem__(self, idx):
        img, text = self.generate_sample()
        encoded = [self.char2idx[ch] for ch in text]
        return self.transform(img), torch.tensor(encoded, dtype=torch.long)
# ---------------------
# 3. Training configuration
# ---------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = len(TextDataset().characters) + 1  # +1 for the CTC blank (index 0)
model = CRNN(num_classes).to(device)
criterion = nn.CTCLoss(blank=0, zero_infinity=True)
optimizer = optim.Adam(model.parameters(), lr=0.001)


def ctc_collate(batch):
    """Stack images; keep variable-length labels as a list of 1-D tensors.

    BUG FIX: the default collate_fn calls torch.stack on the labels and
    crashes because each sample's label has a different length.  train_epoch
    already expects `labels` to be an iterable of per-sample tensors, so a
    plain list is the right shape here.
    """
    images = torch.stack([img for img, _ in batch])
    labels = [label for _, label in batch]
    return images, labels


# Data loader
train_dataset = TextDataset(data_size=1000)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,
                          collate_fn=ctc_collate)
# ---------------------
# 4. 训练循环
# ---------------------
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one full pass over `dataloader`; return the mean CTC loss.

    Expects each batch as (images, labels) where `labels` is an iterable of
    variable-length 1-D index tensors.
    """
    model.train()
    total_loss = 0.0
    for images, targets in dataloader:
        images = images.to(device)
        n = images.size(0)

        # Forward pass: (batch, T, num_classes) logits.
        logits = model(images)
        # Every sample in the batch yields the same number of time steps.
        input_lengths = torch.full(size=(n,), fill_value=logits.size(1),
                                   dtype=torch.long)

        # CTCLoss takes the targets concatenated, plus their lengths.
        target_lengths = torch.tensor([t.numel() for t in targets],
                                      dtype=torch.long)
        flat_targets = torch.cat(list(targets)).to(device)

        # CTCLoss expects log-probabilities shaped (T, batch, num_classes).
        log_probs = logits.permute(1, 0, 2).log_softmax(2)
        loss = criterion(log_probs, flat_targets, input_lengths, target_lengths)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)
# ---------------------
# 5. 解码预测结果
# ---------------------
def decode_prediction(output, label2char):
    """Greedy CTC decode: per-step argmax, collapse repeats, drop blanks.

    `output` is a (batch, T, num_classes) logits tensor; `label2char` maps
    class index (>=1) to its character.  Returns a list of decoded strings.
    """
    probs = F.softmax(output, dim=2)
    best_path = probs.argmax(2).cpu().numpy()  # (batch, T) index matrix
    decoded = []
    for path in best_path:
        chars = []
        last_idx = -1
        for idx in path:
            # Emit only when not the blank (0) and not a repeat of the
            # immediately preceding time step.
            if idx != 0 and idx != last_idx:
                chars.append(label2char[idx])
            last_idx = idx
        decoded.append(''.join(chars))
    return decoded
# ---------------------
# 6. 训练与测试
# ---------------------
# Number of full passes over the training set.
num_epochs = 10
for epoch in range(num_epochs):
    # One epoch over train_loader; train_epoch returns the mean CTC loss.
    loss = train_epoch(model, train_loader, criterion, optimizer, device)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss:.4f}")
# Quick qualitative check on a held-out batch.
def test_model(model, test_loader, label2char):
    """Print greedy-decoded predictions next to ground-truth strings."""
    model.eval()
    with torch.no_grad():
        images, labels = next(iter(test_loader))
        images = images.to(device)
        outputs = model(images)
        preds = decode_prediction(outputs, label2char)
        # Decode the ground-truth index tensors back to text so the printout
        # is readable (the original printed raw tensors), and never index
        # past the actual batch size.
        for i in range(min(5, len(preds))):
            truth = ''.join(label2char[int(idx)] for idx in labels[i])
            print(f"True: {truth}, Pred: {preds[i]}")


# BUG FIX: the default collate_fn cannot stack variable-length label tensors;
# keep the labels as a plain list, mirroring the training loader.
test_loader = DataLoader(
    TextDataset(data_size=5),
    batch_size=5,
    collate_fn=lambda batch: (torch.stack([im for im, _ in batch]),
                              [lb for _, lb in batch]),
)
test_model(model, test_loader, train_dataset.idx2char)