数据生成(合成验证码)
数据集与 DataLoader(含 collate)
模型实现(CRNN: CNN + BiLSTM + CTC)
训练脚本(含 loss / checkpoint)
评估与推理(greedy decode 与示例)
简易 Flask 部署接口
训练超参与实验建议
优化策略与常见问题
合法合规与伦理提醒
1 环境准备
建议使用 Python 3.8+,GPU 可加速训练。
安装必要依赖(示例):
pip install torch torchvision pillow captcha flask
如果使用 GPU,请按你机器选择对应的 torch 版本(官方安装命令)。
2 数据生成(合成验证码)
用 captcha 库快速生成合成数据方便训练。保存文件名里包含标签,便于 Dataset 读取。
保存为 data_gen.py:
data_gen.py
from captcha.image import ImageCaptcha
import random, string, os
from PIL import Image
# Label alphabet: 36 symbols total; the CTC blank is a separate index added later.
CHARS = string.digits + string.ascii_uppercase # 0-9 and A-Z
def generate_one(text, path):
    """Render `text` as a 160x60 captcha image and write it to `path`."""
    ImageCaptcha(width=160, height=60).write(text, path)
def random_text(min_len=4, max_len=5):
    """Return a random label of min_len..max_len characters drawn from CHARS."""
    n = random.randint(min_len, max_len)
    picked = random.choices(CHARS, k=n)
    return ''.join(picked)
def generate_dataset(out_dir='data/train', n=5000):
    """Write `n` captcha PNGs under `out_dir`.

    The ground-truth text is embedded in the filename as '<LABEL>_<i>.png'
    so the Dataset can recover it without a separate annotation file.
    """
    os.makedirs(out_dir, exist_ok=True)
    for idx in range(n):
        label = random_text()
        generate_one(label, os.path.join(out_dir, f"{label}_{idx}.png"))
# Fix: the original read `if name == 'main':` — the dunder underscores in
# __name__/__main__ were stripped by the document formatting, so the script
# raised NameError instead of generating data.
if __name__ == '__main__':
    generate_dataset('data/train', 8000)
    generate_dataset('data/val', 2000)
    print('done')
运行 python data_gen.py 会生成训练和验证集。
3 数据集与 DataLoader(含 collate)
用 PyTorch Dataset 读取图片并将标签编码为整数序列(注意 CTC 要求 targets 合并为 1D)。
保存为 dataset.py:
dataset.py
import torch
from torch.utils.data import Dataset
from PIL import Image
import glob
import os
import string
# Label alphabet: digits 0-9 plus uppercase A-Z (36 symbols).
CHARS = string.digits + string.ascii_uppercase
# Map characters to 1-based indices; index 0 is reserved for the CTC blank.
CHAR2IDX = {c: i+1 for i, c in enumerate(CHARS)} # reserve 0 for blank in CTC
IDX2CHAR = {i+1: c for i, c in enumerate(CHARS)}
import torchvision.transforms as transforms
# Shared preprocessing: fixed 32x160 grayscale tensor normalized to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((32, 160)), # height 32, width 160 fixed for simplicity
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
class CaptchaDataset(Dataset):
    """Dataset over captcha PNGs named '<LABEL>_<i>.png'.

    Yields (image_tensor, label_indices, label_string) per item; the
    collate_fn flattens the variable-length labels for CTCLoss.
    """

    # Fix: the original read `def init` — the dunder underscores were lost
    # in the document formatting, so the constructor never ran.
    def __init__(self, folder, transform=transform):
        self.files = glob.glob(os.path.join(folder, '*.png'))
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = self.files[idx]
        fname = os.path.basename(path)
        # Ground-truth text is the filename prefix before the first '_'.
        label_str = fname.split('_')[0]
        img = Image.open(path).convert('L')  # grayscale
        img = self.transform(img)
        # Encode characters as 1-based indices (0 is reserved for the CTC blank).
        label = torch.tensor([CHAR2IDX[c] for c in label_str], dtype=torch.long)
        return img, label, label_str
def collate_fn(batch):
    """Collate (img, label, label_str) triples for CTC training.

    Returns (imgs, targets, target_lengths, label_strs) where `targets` is
    the 1-D concatenation of all labels, as nn.CTCLoss expects. Per-sample
    input lengths are derived later from the model's output time dimension.
    """
    images, labels, label_strs = zip(*batch)
    stacked = torch.stack(list(images), dim=0)
    # Flatten variable-length labels into one long vector...
    flat_targets = torch.cat(list(labels)).to(torch.long)
    # ...and record each original length so CTC can split it back out.
    lengths = torch.tensor([lab.numel() for lab in labels], dtype=torch.long)
    return stacked, flat_targets, lengths, list(label_strs)
说明:
我们把 blank 保留为索引 0,因此字符索引从 1 开始。
简化处理:图片统一大小 32x160。若希望支持可变宽,可改为 padding / resize 保留比例并调整 collate。
4 模型实现(CRNN)
CRNN 由简单的 CNN 特征提取器 + BiLSTM + 线性投影到字符类别数量(包含 blank)。CTC 接收形状 (T, N, C) 的对数概率。
保存为 model.py:
model.py
import torch
import torch.nn as nn
class CRNN(nn.Module):
    """CRNN: conv feature extractor + two stacked BiLSTM layers.

    forward() returns raw logits of shape (T, N, nclass); apply log_softmax
    before feeding nn.CTCLoss. nclass = number of characters + 1 (CTC blank).
    Fix: the original read `def init` / `self.init()` — the dunder
    underscores were stripped by the document formatting.
    """

    def __init__(self, imgH=32, nc=1, nclass=1+36, nh=256):
        # nclass = num_chars + 1 (blank)
        super(CRNN, self).__init__()
        self.nclass = nclass
        # Per-stage conv hyperparameters (kernel, padding, stride, channels),
        # following the 7-conv configuration of the CRNN paper.
        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        def conv_relu(i, batch_norm=False):
            # Conv -> (optional BN) -> ReLU for stage i.
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            layers = [nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i])]
            if batch_norm:
                layers.append(nn.BatchNorm2d(nOut))
            layers.append(nn.ReLU(True))
            return layers

        # Later pools halve only the height so the width (time axis for CTC)
        # stays long. Size comments assume a 1x32x160 input.
        self.cnn = nn.Sequential(
            *conv_relu(0, False),
            nn.MaxPool2d(2, 2),            # 64 x 16 x 80
            *conv_relu(1, False),
            nn.MaxPool2d(2, 2),            # 128 x 8 x 40
            *conv_relu(2, False),
            *conv_relu(3, False),
            nn.MaxPool2d((2, 1), (2, 1)),  # 256 x 4 x 40
            *conv_relu(4, True),
            *conv_relu(5, True),
            nn.MaxPool2d((2, 1), (2, 1)),  # 512 x 2 x 40
            *conv_relu(6, True),
            # final feature map: (batch, 512, 1, W')
        )
        # RNN head consumes (seq_len, batch, input_size).
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass),
        )

    def forward(self, x):
        # x: (B, C, H, W)
        conv = self.cnn(x)
        b, c, h, w = conv.size()
        # Tightened from the original `h in {1, 2, 4}`: squeeze(2)/permute
        # below only work when height collapsed to 1; h != 1 crashed at the
        # permute anyway, so fail here with a clear message.
        assert h == 1, "expecting conv feature height of 1"
        conv = conv.squeeze(2)        # (b, c, w)
        conv = conv.permute(2, 0, 1)  # (w, b, c)
        output = self.rnn(conv)       # (w, b, nclass)
        return output  # logits (not softmax)
class BidirectionalLSTM(nn.Module):
    """BiLSTM followed by a linear projection: (T, b, nIn) -> (T, b, nOut).

    Fix: the original read `def init` / `self.init()` — the dunder
    underscores were stripped by the document formatting, so the module
    never initialized.
    """

    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        # Bidirectional output concatenates both directions, hence nHidden * 2.
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)   # (T, b, 2*nHidden)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)
        output = self.embedding(t_rec)   # (T*b, nOut)
        output = output.view(T, b, -1)   # (T, b, nOut)
        return output
说明:
nclass = len(CHARS) + 1,+1 是 CTC 的 blank 类。
forward 返回的是 logits (T, N, C),训练时需要对其取 log_softmax -> 使用 nn.CTCLoss(PyTorch 接受 log_probs)。
5 训练脚本(train.py)
训练流程关键点:
模型输出 pred shape (T, N, C),要 log_softmax 后传给 CTCLoss;
input_lengths 为每个样本的时间步长度 T 或计算得到;
targets 为一维长向量,target_lengths 每个标签长度。
保存为 train.py:
train.py
import torch
from torch.utils.data import DataLoader
from dataset import CaptchaDataset, collate_fn
from model import CRNN
import torch.nn.functional as F
import torch.optim as optim
import os
# Use the GPU when available; model and batches are moved here during training.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def train():
    """Train the CRNN with CTC loss for 30 epochs.

    Saves the weights with the best validation exact-match accuracy to
    'best_crnn.pth'. (Reconstructed from a formatting-collapsed original.)
    """
    train_dataset = CaptchaDataset('data/train')
    val_dataset = CaptchaDataset('data/val')
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn)
    nclass = 1 + 36  # 36 characters + CTC blank
    model = CRNN(nclass=nclass).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    # blank=0 matches CHAR2IDX (character labels start at 1); zero_infinity
    # guards against inf losses when an input is shorter than its target.
    ctc_loss = torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
    best_acc = 0.0
    for epoch in range(1, 31):
        model.train()
        total_loss = 0
        for imgs, targets, target_lengths, _ in train_loader:
            imgs = imgs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            preds = model(imgs)  # (T, N, C)
            preds_log_softmax = F.log_softmax(preds, dim=2)
            T, N, C = preds.size()
            # Fixed-size images -> every sample in the batch has T time steps.
            input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
            loss = ctc_loss(preds_log_softmax, targets, input_lengths, target_lengths)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(train_loader)
        val_acc = validate(model, val_loader)
        print(f"Epoch {epoch} Loss {avg_loss:.4f} ValAcc {val_acc:.4f}")
        # Keep only the checkpoint with the best validation accuracy.
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), "best_crnn.pth")
    print("training finished")
def decode_greedy(preds):
    """Greedy CTC decode: per-step argmax, collapse repeats, drop blanks (0).

    preds: (T, N, C) scores (logits or probabilities — only argmax matters).
    Returns a list of N decoded strings.
    """
    from dataset import IDX2CHAR
    best = preds.max(2)[1].transpose(0, 1).cpu().numpy()  # (N, T)
    decoded = []
    for row in best:
        chars = []
        last = 0
        for step in row:
            # Emit a character only at a transition to a non-blank index.
            if step != last and step != 0:
                chars.append(IDX2CHAR.get(int(step), ''))
            last = step
        decoded.append(''.join(chars))
    return decoded
def validate(model, val_loader):
    """Return exact-match accuracy (the whole string must be correct)."""
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for imgs, _targets, _target_lengths, label_strs in val_loader:
            logits = model(imgs.to(device))
            probs = F.log_softmax(logits, dim=2).exp()  # probabilities
            for predicted, truth in zip(decode_greedy(probs), label_strs):
                seen += 1
                if predicted == truth:
                    hits += 1
    return hits / seen if seen > 0 else 0
# Fix: the original read `if name == 'main':` — the dunder underscores in
# __name__/__main__ were stripped by the document formatting.
if __name__ == '__main__':
    train()
说明:
input_lengths 我这里用 T(时间步与批次一致)。若使用可变宽图片需要计算每张图对应的时间步长度(conv 后的宽度)。
decode_greedy 做贪心解码并移除连续重复与 blank。
6 评估与推理(inference)
推理示例 infer.py:
infer.py
import torch
from model import CRNN
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
# Run inference on GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Must mirror the training-time preprocessing defined in dataset.py.
transform = transforms.Compose([
    transforms.Resize((32,160)),
    transforms.ToTensor(),
    transforms.Normalize([0.5],[0.5])
])
def load_model(path='best_crnn.pth'):
    """Build a CRNN, load weights from `path`, and return it in eval mode."""
    net = CRNN(nclass=1+36).to(device)
    state = torch.load(path, map_location=device)
    net.load_state_dict(state)
    net.eval()
    return net
def predict(model, image_path):
    """Decode a single captcha image file into its text (greedy CTC decode)."""
    from dataset import IDX2CHAR
    grey = Image.open(image_path).convert('L')
    batch = transform(grey).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)            # (T, N, C)
        probs = F.softmax(logits, dim=2)
    # Single sample -> take column 0 of the per-step argmax.
    indices = probs.max(2)[1][:, 0].cpu().numpy().tolist()
    # Collapse repeated symbols and drop blanks (index 0).
    chars = []
    last = 0
    for idx in indices:
        if idx != last and idx != 0:
            chars.append(IDX2CHAR.get(int(idx), ''))
        last = idx
    return ''.join(chars)
# Fix: the original read `if name == 'main':` — the dunder underscores in
# __name__/__main__ were stripped by the document formatting.
if __name__ == '__main__':
    model = load_model('best_crnn.pth')
    print(predict(model, 'data/val/SAMPLE.png'))
7 简易 Flask 部署示例
部署为简单 HTTP 接口 app.py:
app.py
from flask import Flask, request, jsonify
from infer import load_model, predict
import os
# Fix: the original read `Flask(name)` — the dunder underscores in __name__
# were stripped by the document formatting, so app creation raised NameError.
app = Flask(__name__)
# Load the trained model once at startup; it is reused across requests.
model = load_model('best_crnn.pth')
@app.route('/predict', methods=['POST'])
def predict_api():
    """POST /predict with an image under form field 'file'.

    Returns JSON {'text': decoded_string}, or {'error': 'no file'} with
    HTTP 400 when no file was uploaded.
    """
    if 'file' not in request.files:
        return jsonify({'error':'no file'}), 400
    f = request.files['file']
    # Fix: the original saved every upload to a fixed 'tmp_upload.png', so
    # two concurrent requests clobbered each other's image; a unique temp
    # file per request avoids the race. try/finally guarantees cleanup even
    # when predict() raises.
    import tempfile
    fd, fname = tempfile.mkstemp(suffix='.png')
    os.close(fd)  # f.save reopens by name; close the raw descriptor
    try:
        f.save(fname)
        text = predict(model, fname)
    finally:
        os.remove(fname)
    return jsonify({'text': text})
# Fix: the original read `if name == 'main':` — the dunder underscores in
# __name__/__main__ were stripped by the document formatting.
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)
启动后可用 curl 上传图片测试:
curl -F "file=@captcha.png" http://127.0.0.1:5000/predict