
CRNN Captcha Recognition with PyTorch: An End-to-End Walkthrough

Data generation (synthetic captchas)

Dataset and DataLoader (with collate)

Model implementation (CRNN: CNN + BiLSTM + CTC)

Training script (with loss / checkpointing)

Evaluation and inference (greedy decoding with examples)

A simple Flask deployment endpoint

Training hyperparameters and experiment tips

Optimization strategies and common pitfalls

Legal, compliance, and ethics reminders

1 Environment Setup
Python 3.8+ is recommended; a GPU will speed up training.

Install the required dependencies (example):

pip install torch torchvision pillow captcha flask
If you use a GPU, install the torch build that matches your machine (see the official install command selector).
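A quick optional sanity check to confirm the install picked up your GPU, and which device the scripts below will end up using:

```python
import torch

print(torch.__version__)              # installed PyTorch version
use_cuda = torch.cuda.is_available()  # True only with a CUDA build plus a visible GPU
device = torch.device("cuda" if use_cuda else "cpu")
print(device)                         # the device the scripts below will select
```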

2 Data Generation (Synthetic Captchas)
Use the captcha library to quickly generate synthetic data for training. Embedding the label in the filename makes it easy for the Dataset to read back.

Save as data_gen.py:

from captcha.image import ImageCaptcha
import random, string, os

CHARS = string.digits + string.ascii_uppercase  # 0-9 and A-Z

def generate_one(text, path):
    image = ImageCaptcha(width=160, height=60)
    image.write(text, path)

def random_text(min_len=4, max_len=5):
    length = random.randint(min_len, max_len)
    return ''.join(random.choices(CHARS, k=length))

def generate_dataset(out_dir='data/train', n=5000):
    os.makedirs(out_dir, exist_ok=True)
    for i in range(n):
        txt = random_text()
        fname = f"{txt}_{i}.png"
        generate_one(txt, os.path.join(out_dir, fname))

if __name__ == '__main__':
    generate_dataset('data/train', 8000)
    generate_dataset('data/val', 2000)
    print('done')
Running python data_gen.py generates the training and validation sets.

3 Dataset and DataLoader (with collate)
Use a PyTorch Dataset to read the images and encode each label as a sequence of integers (note: CTC expects all targets concatenated into a single 1-D tensor).

Save as dataset.py:

import torch
from torch.utils.data import Dataset
from PIL import Image
import glob
import os
import string
import torchvision.transforms as transforms

CHARS = string.digits + string.ascii_uppercase
CHAR2IDX = {c: i + 1 for i, c in enumerate(CHARS)}  # reserve 0 for the CTC blank
IDX2CHAR = {i + 1: c for i, c in enumerate(CHARS)}

transform = transforms.Compose([
    transforms.Resize((32, 160)),  # fixed height 32, width 160 for simplicity
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

class CaptchaDataset(Dataset):
    def __init__(self, folder, transform=transform):
        self.files = glob.glob(os.path.join(folder, '*.png'))
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = self.files[idx]
        fname = os.path.basename(path)
        label_str = fname.split('_')[0]
        img = Image.open(path).convert('L')  # grayscale
        img = self.transform(img)
        # convert the label to a list of integer indices
        label = torch.tensor([CHAR2IDX[c] for c in label_str], dtype=torch.long)
        return img, label, label_str

def collate_fn(batch):
    # batch: list of (img, label, label_str)
    imgs = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    label_strs = [item[2] for item in batch]
    imgs = torch.stack(imgs, dim=0)
    # CTC needs the targets flattened into one 1-D tensor plus per-sample lengths
    targets = torch.cat(labels).to(torch.long)
    target_lengths = torch.tensor([len(l) for l in labels], dtype=torch.long)
    # input_lengths for CTC is computed later from the model's output time steps
    return imgs, targets, target_lengths, label_strs
Notes:

We reserve index 0 for the CTC blank, so character indices start at 1.

For simplicity, every image is resized to a fixed 32x160. To support variable widths, switch to an aspect-preserving resize with padding and adjust the collate function accordingly.
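A minimal sketch of such a variable-width collate (the name collate_variable_width and the right-padding scheme are my assumptions, not part of the code above): pad every image to the widest one in the batch and also return the true widths, from which per-sample CTC input lengths can later be derived.

```python
import torch
import torch.nn.functional as F

def collate_variable_width(batch):
    # batch: list of (img, label, label_str); img has shape (1, 32, W_i)
    imgs, labels, label_strs = zip(*batch)
    max_w = max(img.shape[-1] for img in imgs)
    # right-pad each image to the widest width in the batch
    # (value -1.0 corresponds to black after Normalize([0.5], [0.5]))
    padded = [F.pad(img, (0, max_w - img.shape[-1]), value=-1.0) for img in imgs]
    widths = torch.tensor([img.shape[-1] for img in imgs], dtype=torch.long)
    targets = torch.cat(labels).to(torch.long)
    target_lengths = torch.tensor([len(l) for l in labels], dtype=torch.long)
    return torch.stack(padded), targets, target_lengths, widths, list(label_strs)
```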

4 Model Implementation (CRNN)
The CRNN consists of a plain CNN feature extractor + BiLSTM + a linear projection to the number of character classes (including the blank). CTC consumes log-probabilities of shape (T, N, C).

Save as model.py:

import torch
import torch.nn as nn

class BidirectionalLSTM(nn.Module):
    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)   # (T, b, 2*nHidden)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)
        output = self.embedding(t_rec)   # (T*b, nOut)
        output = output.view(T, b, -1)
        return output

class CRNN(nn.Module):
    def __init__(self, imgH=32, nc=1, nclass=1 + 36, nh=256):
        # nclass = num_chars + 1 (blank)
        super(CRNN, self).__init__()
        self.nclass = nclass
        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        def conv_relu(i, batch_norm=False):
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            layers = [nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i])]
            if batch_norm:
                layers.append(nn.BatchNorm2d(nOut))
            layers.append(nn.ReLU(True))
            return layers

        # build the conv stack, following the CRNN paper
        self.cnn = nn.Sequential(
            *conv_relu(0, False),
            nn.MaxPool2d(2, 2),            # 64 x 16 x 80
            *conv_relu(1, False),
            nn.MaxPool2d(2, 2),            # 128 x 8 x 40
            *conv_relu(2, False),
            *conv_relu(3, False),
            nn.MaxPool2d((2, 1), (2, 1)),  # 256 x 4 x 40
            *conv_relu(4, True),
            *conv_relu(5, True),
            nn.MaxPool2d((2, 1), (2, 1)),  # 512 x 2 x 40
            *conv_relu(6, True),
            # feature map is now roughly (batch, 512, 1, width')
        )

        # the RNN expects (seq_len, batch, input_size)
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass)
        )

    def forward(self, x):
        # x: (B, C, H, W)
        conv = self.cnn(x)
        b, c, h, w = conv.size()
        assert h == 1, "expected feature-map height 1"
        conv = conv.squeeze(2)        # (b, c, w)  collapse the height dimension
        conv = conv.permute(2, 0, 1)  # (w, b, c)
        output = self.rnn(conv)       # (w, b, nclass)
        return output                 # logits (no softmax applied)

Notes:

nclass = len(CHARS) + 1; the +1 is the CTC blank class.

forward returns logits of shape (T, N, C); at training time, apply log_softmax and pass the result to nn.CTCLoss (PyTorch expects log-probabilities).
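To make the calling convention concrete, here is a self-contained toy example of the shapes nn.CTCLoss expects, using random values in place of a real model:

```python
import torch
import torch.nn.functional as F

T, N, C = 40, 4, 37                       # time steps, batch size, classes (36 chars + blank)
logits = torch.randn(T, N, C)             # stand-in for the CRNN output
log_probs = F.log_softmax(logits, dim=2)  # CTCLoss wants log-probabilities

targets = torch.randint(1, C, (N, 5)).flatten()         # indices >= 1; 0 is the blank
input_lengths = torch.full((N,), T, dtype=torch.long)   # time steps per sample
target_lengths = torch.full((N,), 5, dtype=torch.long)  # label length per sample

ctc = torch.nn.CTCLoss(blank=0, zero_infinity=True)
loss = ctc(log_probs, targets, input_lengths, target_lengths)
print(float(loss))  # a finite scalar
```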

5 Training Script (train.py)
Key points of the training loop:

the model output has shape (T, N, C); apply log_softmax before passing it to CTCLoss;

input_lengths is the number of time steps T per sample (fixed here, or computed per image);

targets is one flattened 1-D tensor, and target_lengths holds each label's length.

Save as train.py:

import torch
from torch.utils.data import DataLoader
from dataset import CaptchaDataset, collate_fn
from model import CRNN
import torch.nn.functional as F
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train():
    train_dataset = CaptchaDataset('data/train')
    val_dataset = CaptchaDataset('data/val')
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn)

    nclass = 1 + 36
    model = CRNN(nclass=nclass).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    ctc_loss = torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)

    best_acc = 0.0
    for epoch in range(1, 31):
        model.train()
        total_loss = 0
        for imgs, targets, target_lengths, _ in train_loader:
            imgs = imgs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            preds = model(imgs)  # (T, N, C)
            preds_log_softmax = F.log_softmax(preds, dim=2)
            T, N, C = preds.size()
            input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
            loss = ctc_loss(preds_log_softmax, targets, input_lengths, target_lengths)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(train_loader)
        val_acc = validate(model, val_loader)
        print(f"Epoch {epoch} Loss {avg_loss:.4f} ValAcc {val_acc:.4f}")
        # keep the best checkpoint
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), "best_crnn.pth")
    print("training finished")

def decode_greedy(preds):
    # preds: (T, N, C) logits or probabilities
    from dataset import IDX2CHAR
    _, max_idx = preds.max(2)                        # (T, N)
    max_idx = max_idx.transpose(0, 1).cpu().numpy()  # (N, T)
    results = []
    for seq in max_idx:
        # collapse consecutive repeats and remove blanks (index 0)
        prev = 0
        out = []
        for idx in seq:
            if idx != prev and idx != 0:
                out.append(IDX2CHAR.get(int(idx), ''))
            prev = idx
        results.append(''.join(out))
    return results

def validate(model, val_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, targets, target_lengths, label_strs in val_loader:
            imgs = imgs.to(device)
            preds = model(imgs)
            preds = F.log_softmax(preds, dim=2)
            preds_raw = preds.exp()  # probabilities
            pred_texts = decode_greedy(preds_raw)
            for p, gt in zip(pred_texts, label_strs):
                if p == gt:
                    correct += 1
                total += 1
    return correct / total if total > 0 else 0

if __name__ == '__main__':
    train()
Notes:

input_lengths here is simply T for every sample (the time steps are identical across the batch). With variable-width images you would instead compute each image's time steps, i.e. its width after the conv stack.
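For the fixed architecture above, those per-image time steps can be derived from the input width: the two (2, 2) max pools each halve the width, the (2, 1) pools only halve the height, and the final 2x2 conv (stride 1, no padding) removes one column. A small helper (the name is mine) that mirrors this arithmetic:

```python
def conv_time_steps(width: int) -> int:
    # follows the conv stack in model.py
    w = width // 2  # MaxPool2d(2, 2)
    w = w // 2      # MaxPool2d(2, 2)
                    # the two MaxPool2d((2, 1)) layers leave the width unchanged
    w = w - 1       # final Conv2d with kernel 2, stride 1, padding 0
    return w

print(conv_time_steps(160))  # -> 39 time steps for the fixed 32x160 input
```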

decode_greedy performs greedy decoding, collapsing consecutive repeats and removing blanks.
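The collapse rule is easy to check on a toy index sequence, isolated from any model (ctc_collapse below is a stripped-down helper equivalent to the inner loop of decode_greedy):

```python
def ctc_collapse(seq, idx2char):
    # merge consecutive repeats, then drop blanks (index 0)
    prev, out = 0, []
    for idx in seq:
        if idx != prev and idx != 0:
            out.append(idx2char[idx])
        prev = idx
    return ''.join(out)

# a blank between two equal indices keeps them as separate characters
print(ctc_collapse([0, 1, 1, 0, 2, 2], {1: 'A', 2: 'B'}))  # -> AB
print(ctc_collapse([1, 0, 1, 2], {1: 'A', 2: 'B'}))        # -> AAB
```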

6 Evaluation and Inference
Inference example, infer.py:

import torch
from model import CRNN
from dataset import IDX2CHAR
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
transform = transforms.Compose([
    transforms.Resize((32, 160)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

def load_model(path='best_crnn.pth'):
    model = CRNN(nclass=1 + 36).to(device)
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()
    return model

def predict(model, image_path):
    img = Image.open(image_path).convert('L')
    x = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        preds = model(x)  # (T, N, C)
        probs = F.softmax(preds, dim=2)
    # greedy decode, same collapse rule as decode_greedy in train.py
    _, max_idx = probs.max(2)  # (T, N)
    seq = max_idx[:, 0].cpu().numpy().tolist()
    prev = 0
    out = []
    for idx in seq:
        if idx != prev and idx != 0:
            out.append(IDX2CHAR.get(int(idx), ''))
        prev = idx
    return ''.join(out)

if __name__ == '__main__':
    model = load_model('best_crnn.pth')
    print(predict(model, 'data/val/SAMPLE.png'))
7 A Simple Flask Deployment Example
Deploy as a minimal HTTP endpoint, app.py:

from flask import Flask, request, jsonify
from infer import load_model, predict
import os

app = Flask(__name__)
model = load_model('best_crnn.pth')

@app.route('/predict', methods=['POST'])
def predict_api():
    if 'file' not in request.files:
        return jsonify({'error': 'no file'}), 400
    f = request.files['file']
    fname = 'tmp_upload.png'
    f.save(fname)
    text = predict(model, fname)
    os.remove(fname)
    return jsonify({'text': text})

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)
Once the server is running, test it by uploading an image with curl:

curl -F "file=@captcha.png" http://127.0.0.1:5000/predict
