你的浏览器不支持canvas

Welcome to my home

图像自回归生成模型PixelCNN

Date: Author: cxb

PixelCNN 是一个非常有代表性的 图像自回归生成模型,它可以逐像素地生成图像,尤其是用于建模像素之间的依赖关系。

图像自回归生成模型PixelCNN


🧱 PixelCNN:基本思想

✅ 核心概念:

将一张图像看作一个像素序列,按从左到右、从上到下的顺序逐像素建模, 每一个像素点 $x_{i,j}$ 的生成都依赖于其左边和上边的像素


📐 数学建模方式

对一张 $n \times n$ 的图像,PixelCNN 建模它的联合分布: \(P(\mathbf{x}) = \prod_{i=1}^n \prod_{j=1}^n P(x_{i,j} \mid x_{<i,j})\) 其中 $x_{<i,j}$ 表示所有在 $(i,j)$ 之前的像素(通常是左侧和上方)。


🧠 结构细节

1. Masked Convolution(掩码卷积)

  • 为了保证像素 $x_{i,j}$ 只能看到它之前的像素,PixelCNN 使用了特制的卷积核,称为 掩码卷积
  • 两种掩码:
    • Mask A:用于第一层,完全避免“看见当前像素”
    • Mask B:用于后续层,可以“看到当前像素的通道之间”

🛡️ 这是为了解决信息泄露问题(不能提前看到要预测的像素)


2. 通道建模(RGB)

  • 每个像素包含多个通道(R、G、B)
  • 像素内部顺序建模:先生成 R,然后 G,最后 B,每个都依赖前面信息

3. 输出分布

  • 对每个像素,输出的是 离散分布(例如 256 个值)
  • 通常用 Softmax 对每个通道进行建模

🔄 采样过程(逐像素生成)

  1. 初始化一个空图像(全为 0 或随机值)
  2. 按照从左到右、从上到下的顺序生成每个像素
  3. 每次将生成的像素作为输入,预测下一个像素

🧩 PixelCNN 的优缺点

优点 ✅ 缺点 ❌
精确建模像素间依赖关系 推理速度慢(逐像素生成)
结构简单,容易实现 长距离像素依赖建模有限
不需要隐变量(如 VAE 中的 z) 图像质量相对较差(早期版本)

🧩 1. MaskedConv2d(核心掩码卷积)

import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedConv2d(nn.Conv2d):
    def __init__(self, mask_type, in_channels, out_channels, kernel_size, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
        assert mask_type in ('A', 'B'), "mask_type must be 'A' or 'B'"
        self.register_buffer("mask", self.weight.data.clone())

        _, _, kH, kW = self.weight.size()
        self.mask.fill_(1)
        yc, xc = kH // 2, kW // 2
		
        self.mask[:, :, yc, xc + (mask_type == 'B'):] = 0
        self.mask[:, :, yc + 1:] = 0

    def forward(self, x):
        self.weight.data *= self.mask
        return super().forward(x)

👆 说明:

  • Mask A:用于第一层,不允许看到当前像素点当前像素(中心)右侧像素 全部屏蔽掉!
  • Mask B:用于后续层,只屏蔽右侧像素,但保留中心像素本身

一个 7x7 的卷积核被掩码成这样(1 可用,0 被屏蔽):

1 1 1 1 1 1 1
1 1 1 1 1 1 1
1 1 1 1 1 1 1
1 1 1 0 0 0 0  ← 中心行:当前像素和右边都屏蔽
0 0 0 0 0 0 0  ← 下方行全屏蔽
0 0 0 0 0 0 0
0 0 0 0 0 0 0

🏗️ 2. 简单 PixelCNN 网络结构

class SimplePixelCNN(nn.Module):
    def __init__(self, input_channels=3, hidden_channels=64, num_layers=7):
        super().__init__()
        layers = []

        # 第一层使用 Mask A
        layers.append(MaskedConv2d('A', input_channels, hidden_channels, kernel_size=7, padding=3))

        # 后续层使用 Mask B
        for _ in range(num_layers - 2):
            layers.append(nn.ReLU())
            layers.append(MaskedConv2d('B', hidden_channels, hidden_channels, kernel_size=7, padding=3))

        # 最后一层输出每个通道的 256 维 softmax logits
        layers.append(nn.ReLU())
        layers.append(nn.Conv2d(hidden_channels, input_channels * 256, kernel_size=1))

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        # reshape: [B, 3*256, H, W] → [B, 3, 256, H, W]
        B, C, H, W = out.shape
        out = out.view(B, 3, 256, H, W)
        return out

📥 输入输出说明

  • 输入:x 是 [B, 3, H, W] 的图像 tensor,像素值通常是 one-hot 或离散整数。
  • 输出:softmax logits(用于对每个像素点每个通道进行分类)

🧪 采样流程

def sample_from_pixelcnn(model, image_shape, device='cuda'):
    model.eval()
    B, C, H, W = image_shape
    image = torch.zeros((B, C, H, W), dtype=torch.float32, device=device)

    with torch.no_grad():
        for i in range(H):
            for j in range(W):
                out = model(image)  # [B, 3, 256, H, W]
                probs = F.softmax(out[:, :, :, i, j], dim=-1)  # [B, 3, 256]
                pixel = torch.multinomial(probs.view(B * C, -1), 1).view(B, C).float() / 255.0
                image[:, :, i, j] = pixel

    return image