PixelCNN 是一个非常有代表性的 图像自回归生成模型,它可以逐像素地生成图像,尤其是用于建模像素之间的依赖关系。
图像自回归生成模型PixelCNN
🧱 PixelCNN:基本思想
✅ 核心概念:
将一张图像看作一个像素序列,按从左到右、从上到下的顺序逐像素建模, 每一个像素点 $x_{i,j}$ 的生成都依赖于其左边和上边的像素。
📐 数学建模方式
对一张 $n \times n$ 的图像,PixelCNN 建模它的联合分布: \(P(\mathbf{x}) = \prod_{i=1}^n \prod_{j=1}^n P(x_{i,j} \mid x_{<i,j})\) 其中 $x_{<i,j}$ 表示所有在 $(i,j)$ 之前的像素(通常是左侧和上方)。
🧠 结构细节
1. Masked Convolution(掩码卷积)
- 为了保证像素 $x_{i,j}$ 只能看到它之前的像素,PixelCNN 使用了特制的卷积核,称为 掩码卷积。
- 两种掩码:
- Mask A:用于第一层,完全避免“看见当前像素”
- Mask B:用于后续层,可以“看到当前像素的通道之间”
🛡️ 这是为了解决信息泄露问题(不能提前看到要预测的像素)
2. 通道建模(RGB)
- 每个像素包含多个通道(R、G、B)
- 像素内部顺序建模:先生成 R,然后 G,最后 B,每个都依赖前面信息
3. 输出分布
- 对每个像素,输出的是 离散分布(例如 256 个值)
- 通常用 Softmax 对每个通道进行建模
🔄 采样过程(逐像素生成)
- 初始化一个空图像(全为 0 或随机值)
- 按照从左到右、从上到下的顺序生成每个像素
- 每次将生成的像素作为输入,预测下一个像素
🧩 PixelCNN 的优缺点
优点 ✅ | 缺点 ❌ |
---|---|
精确建模像素间依赖关系 | 推理速度慢(逐像素生成) |
结构简单,容易实现 | 长距离像素依赖建模有限 |
不需要隐变量(如 VAE 中的 z) | 图像质量相对较差(早期版本) |
🧩 1. MaskedConv2d(核心掩码卷积)
import torch
import torch.nn as nn
import torch.nn.functional as F
class MaskedConv2d(nn.Conv2d):
def __init__(self, mask_type, in_channels, out_channels, kernel_size, **kwargs):
super().__init__(in_channels, out_channels, kernel_size, **kwargs)
assert mask_type in ('A', 'B'), "mask_type must be 'A' or 'B'"
self.register_buffer("mask", self.weight.data.clone())
_, _, kH, kW = self.weight.size()
self.mask.fill_(1)
yc, xc = kH // 2, kW // 2
self.mask[:, :, yc, xc + (mask_type == 'B'):] = 0
self.mask[:, :, yc + 1:] = 0
def forward(self, x):
self.weight.data *= self.mask
return super().forward(x)
👆 说明:
- Mask A:用于第一层,不允许看到当前像素点,当前像素(中心) 和 右侧像素 全部屏蔽掉!
- Mask B:用于后续层,只屏蔽右侧像素,但保留中心像素本身
一个 7x7 的卷积核被掩码成这样(1 可用,0 被屏蔽):
1 1 1 1 1 1 1
1 1 1 1 1 1 1
1 1 1 1 1 1 1
1 1 1 0 0 0 0 ← 中心行:当前像素和右边都屏蔽
0 0 0 0 0 0 0 ← 下方行全屏蔽
0 0 0 0 0 0 0
0 0 0 0 0 0 0
🏗️ 2. 简单 PixelCNN 网络结构
class SimplePixelCNN(nn.Module):
def __init__(self, input_channels=3, hidden_channels=64, num_layers=7):
super().__init__()
layers = []
# 第一层使用 Mask A
layers.append(MaskedConv2d('A', input_channels, hidden_channels, kernel_size=7, padding=3))
# 后续层使用 Mask B
for _ in range(num_layers - 2):
layers.append(nn.ReLU())
layers.append(MaskedConv2d('B', hidden_channels, hidden_channels, kernel_size=7, padding=3))
# 最后一层输出每个通道的 256 维 softmax logits
layers.append(nn.ReLU())
layers.append(nn.Conv2d(hidden_channels, input_channels * 256, kernel_size=1))
self.net = nn.Sequential(*layers)
def forward(self, x):
out = self.net(x)
# reshape: [B, 3*256, H, W] → [B, 3, 256, H, W]
B, C, H, W = out.shape
out = out.view(B, 3, 256, H, W)
return out
📥 输入输出说明
- 输入:
x
是 [B, 3, H, W] 的图像 tensor,像素值通常是 one-hot 或离散整数。 - 输出:softmax logits(用于对每个像素点每个通道进行分类)
🧪 采样流程
def sample_from_pixelcnn(model, image_shape, device='cuda'):
model.eval()
B, C, H, W = image_shape
image = torch.zeros((B, C, H, W), dtype=torch.float32, device=device)
with torch.no_grad():
for i in range(H):
for j in range(W):
out = model(image) # [B, 3, 256, H, W]
probs = F.softmax(out[:, :, :, i, j], dim=-1) # [B, 3, 256]
pixel = torch.multinomial(probs.view(B * C, -1), 1).view(B, C).float() / 255.0
image[:, :, i, j] = pixel
return image