CRNN 卷积循环神经网络


CRNN 卷积循环神经网络

CRNN(Convolutional Recurrent Neural Network,卷积循环神经网络)是一种结合了卷积神经网络(CNN)和循环神经网络(RNN)的混合深度学习模型,主要用于处理序列建模与预测 任务,尤其在图像文本识别(如OCR)语音识别手写识别 等领域表现优异。以下是其核心结构和特点:

1. 模型结构

CRNN通常由三部分组成:

  1. CNN特征提取层
    • 使用类似VGG或ResNet的卷积网络,提取输入图像的局部特征(如边缘、纹理、字符形状等)。
    • 例如:输入一张自然场景的文字图像,CNN会将其转换为特征序列(每个特征对应图像中的一列或一块区域)。
  2. RNN序列建模层
    • 通常采用双向LSTM(Bi-LSTM)或GRU,捕捉特征序列中的上下文依赖关系
    • 例如:在文本识别中,RNN能学习字符间的顺序关系(如“hello”中的“h”后接“e”的概率)。
  3. CTC解码层
    • 使用Connectionist Temporal Classification (CTC) 损失函数,解决输入序列与输出标签长度不一致的问题。
    • 例如:将RNN输出的变长特征序列映射为最终文本标签(如“hello”),无需预先对齐输入与输出。

示例:OCR中的CRNN工作流程

  1. 输入 :一张包含文字的图像(如“HELLO”)。
  2. CNN :提取图像的特征序列(每个特征对应一个字符区域)。
  3. RNN :捕捉字符间的上下文关系(如“H”后接“E”的概率更高)。
  4. CTC :将RNN输出的概率分布转换为最终文本标签(如“HELLO”),无需字符位置对齐。

CRNN通过融合CNN的空间特征提取能力和RNN的序列建模能力,成为处理序列数据的经典架构,尤其在文本识别领域具有重要地位。

代码示例

以下是一个使用PyTorch实现CRNN(CNN+BiLSTM+CTC)的示例代码,适用于图像文本识别(OCR)任务:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import random

# ---------------------
# 1. 定义CRNN模型
# ---------------------
class CRNN(nn.Module):
    def __init__(self, num_classes, hidden_size=256):
        super(CRNN, self).__init__()
        # CNN部分 (特征提取)
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 输出尺寸: (batch, 64, H/2, W/2)
            
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 输出尺寸: (batch, 128, H/4, W/4)
            
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),  # 输出尺寸: (batch, 256, H/8, W/4)
            
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),  # 输出尺寸: (batch, 512, H/16, W/4)
            
            nn.Conv2d(512, 512, kernel_size=2, stride=1, padding=0)  # 输出尺寸: (batch, 512, 1, W/4-1)
        )
        
        # RNN部分 (序列建模)
        self.rnn = nn.Sequential(
            nn.LSTM(512, hidden_size, bidirectional=True, batch_first=True),
            nn.Dropout(0.5),
            nn.LSTM(hidden_size * 2, hidden_size, bidirectional=True, batch_first=True)
        )
        
        # 全连接层 (输出字符概率)
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        # CNN特征提取
        x = self.cnn(x)  # (batch, 512, 1, W)
        x = x.squeeze(2)  # 去掉高度维度 (batch, 512, W)
        x = x.permute(0, 2, 1)  # 转换为 (batch, W, 512)
        
        # RNN处理序列
        x, _ = self.rnn(x)  # (batch, W, hidden_size*2)
        x = self.fc(x)  # (batch, W, num_classes)
        
        return x

# ---------------------
# 2. 数据预处理与加载
# ---------------------
class TextDataset(Dataset):
    def __init__(self, data_size=1000, max_length=10, img_width=128, img_height=32):
        self.data_size = data_size
        self.max_length = max_length
        self.img_width = img_width
        self.img_height = img_height
        self.characters = 'abcdefghijklmnopqrstuvwxyz0123456789'
        self.char2idx = {char: idx+1 for idx, char in enumerate(self.characters)}
        self.idx2char = {idx+1: char for idx, char in enumerate(self.characters)}
        self.transform = transforms.Compose([
            transforms.Resize((img_height, img_width)),
            transforms.Grayscale(),
            transforms.ToTensor()
        ])

    def generate_sample(self):
        # 生成随机文本图像(示例)
        text = ''.join(random.choices(self.characters, k=random.randint(1, self.max_length)))
        img = Image.new('L', (self.img_width, self.img_height), color=255)
        return img, text

    def __len__(self):
        return self.data_size

    def __getitem__(self, idx):
        img, text = self.generate_sample()
        img = self.transform(img)
        label = [self.char2idx[c] for c in text]
        return img, torch.tensor(label, dtype=torch.long)

# ---------------------
# 3. 训练配置
# ---------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = len(TextDataset().characters) + 1  # +1 for blank token in CTC
model = CRNN(num_classes).to(device)
criterion = nn.CTCLoss(blank=0, zero_infinity=True)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 数据加载器
train_dataset = TextDataset(data_size=1000)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# ---------------------
# 4. 训练循环
# ---------------------
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    epoch_loss = 0
    for images, labels in dataloader:
        images = images.to(device)
        batch_size = images.size(0)
        
        # 网络前向传播
        outputs = model(images)
        output_lengths = torch.full(size=(batch_size,), fill_value=outputs.size(1), dtype=torch.long)
        
        # 处理标签
        label_lengths = torch.tensor([len(label) for label in labels], dtype=torch.long)
        labels = torch.cat(labels).to(device)
        
        # 计算CTC损失
        loss = criterion(
            outputs.permute(1, 0, 2).log_softmax(2),  # CTC expects (T, N, C)
            labels,
            output_lengths,
            label_lengths
        )
        
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
    return epoch_loss / len(dataloader)

# ---------------------
# 5. 解码预测结果
# ---------------------
def decode_prediction(output, label2char):
    output = F.softmax(output, dim=2)
    output = output.argmax(2).cpu().numpy()  # (batch, T)
    
    predictions = []
    for sample in output:
        pred = []
        prev_char = -1
        for char_idx in sample:
            if char_idx != 0 and char_idx != prev_char:  # 跳过blank token和重复字符
                pred.append(label2char[char_idx])
            prev_char = char_idx
        predictions.append(''.join(pred))
    return predictions

# ---------------------
# 6. 训练与测试
# ---------------------
num_epochs = 10
for epoch in range(num_epochs):
    loss = train_epoch(model, train_loader, criterion, optimizer, device)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss:.4f}")

# 测试示例
def test_model(model, test_loader, label2char):
    model.eval()
    with torch.no_grad():
        images, labels = next(iter(test_loader))
        images = images.to(device)
        outputs = model(images)
        preds = decode_prediction(outputs, label2char)
        
        for i in range(5):
            print(f"True: {labels[i]}, Pred: {preds[i]}")

test_loader = DataLoader(TextDataset(data_size=5), batch_size=5)
test_model(model, test_loader, train_dataset.idx2char)

Author: qwq小小舒
Reprint policy: All articles in this blog are used except for special statements CC BY 4.0 reprint policy. If reproduced, please indicate source qwq小小舒 !
  TOC