VAE(Variational Autoencoder)とは

VAE(Variational Autoencoder)は、画像を「圧縮」と「復元」するためのニューラルネットワークです。
Stable Diffusion では、画像生成の効率を高めるために重要な役割を担っています。

VAEは主に次の2つの構成要素から成り立っています:

  • Encoder(エンコーダー)
    画像(RGB)を潜在表現(latent)に圧縮する
  • Decoder(デコーダー)
    潜在表現(latent)を画像に復元する

Stable DiffusionにおけるVAEの役割

Stable Diffusionでは、直接画像を生成するのではなく、次のような流れになっています:

  1. ノイズから潜在空間(latent)を生成
  2. そのlatentをVAEのdecoderで画像に変換

また、画像を入力として使う場合(img2imgなど)は:

  1. 画像をVAEのencoderでlatentに変換
  2. latentを加工・生成
  3. decoderで画像に戻す

つまり:

処理 使用する部分
画像 → latent Encoder
latent → 画像 Decoder

safetensorsのVAEファイルについて

wan_2.1_vae.safetensors のようなファイルには、通常:

  • encoder の重み
  • decoder の重み

1つのファイルにまとめて保存されています。

内部的には以下のようなキーで分かれています:

  • encoder.*
  • decoder.*

ONNX化のために分離する理由

ONNXに変換する際は、以下のように分けて使うことが多いです:

  • vae_encoder.onnx(画像 → latent)
  • vae_decoder.onnx(latent → 画像)

理由:

  • 推論用途では decoder だけ使うケースが多い
  • 軽量化・最適化がしやすい
  • モバイルやWeb環境で扱いやすい

ONNXに変換する手順

以下は、wan_2.1_vae.safetensors
encoder / decoder に分けて ONNX に書き出すサンプルコードです。

必要ライブラリ

pip install torch diffusers safetensors onnx

サンプルコード(完全版)

import torch
from safetensors.torch import load_file
from diffusers import AutoencoderKL
import warnings
warnings.filterwarnings("ignore")

# =========================
# 1. VAEのロード(修正ポイント)
# =========================
vae_path = "wan_2.1_vae.safetensors"

# from_single_file が最も安定
vae = AutoencoderKL.from_single_file(vae_path)
vae.eval()

# VAEのスケーリングファクターを確認(通常 SD系は 0.18215 前後)
print(f"VAE scaling_factor: {vae.config.scaling_factor}")

# =========================
# 2. Encoder / Decoder の準備
# =========================
# Encoder部分(画像 → latent distribution)
encoder = vae.encoder
quant_conv = vae.quant_conv   # posteriorを4chに変換するconv

# Decoder部分(latent → 画像)
decoder = vae.decoder
post_quant_conv = vae.post_quant_conv  # latentをDecoderに入力する前のconv

# =========================
# 3. ダミー入力作成
# =========================
# 例: 512x512画像(任意サイズでdynamic対応)
dummy_image = torch.randn(1, 3, 512, 512)   # 値域は後で -1~1 に正規化

# latentは通常 1/8 スケール + 4チャネル
dummy_latent = torch.randn(1, 4, 64, 64)

# =========================
# 4. Encoder ONNX エクスポート(画像 → latent mean)
# =========================
class VAEEncoder(torch.nn.Module):
    def __init__(self, encoder, quant_conv):
        super().__init__()
        self.encoder = encoder
        self.quant_conv = quant_conv

    def forward(self, x):
        # Stable Diffusionの標準的な前処理: [-1, 1] に正規化
        x = x * 2.0 - 1.0
        h = self.encoder(x)
        moments = self.quant_conv(h)          # shape: (B, 8, h/8, w/8)
        mean, _ = moments.chunk(2, dim=1)     # meanのみ使用(4ch)
        return mean

vae_encoder = VAEEncoder(encoder, quant_conv).eval()

torch.onnx.export(
    vae_encoder,
    dummy_image,
    "vae_encoder.onnx",
    input_names=["image"],
    output_names=["latent"],                  # meanのみ出力
    opset_version=18,
    do_constant_folding=True,
    dynamic_axes={
        "image": {0: "batch", 2: "height", 3: "width"},
        "latent": {0: "batch", 2: "height", 3: "width"},
    },
    export_params=True,
)

# =========================
# 5. Decoder ONNX エクスポート(latent → 画像)
# =========================
class VAEDecoder(torch.nn.Module):
    def __init__(self, post_quant_conv, decoder):
        super().__init__()
        self.post_quant_conv = post_quant_conv
        self.decoder = decoder

    def forward(self, latent):
        # latentをスケーリング(推論時は通常 1/scaling_factor で補正)
        latent = latent / vae.config.scaling_factor
        z = self.post_quant_conv(latent)
        image = self.decoder(z)
        return image

vae_decoder = VAEDecoder(post_quant_conv, decoder).eval()

torch.onnx.export(
    vae_decoder,
    dummy_latent,
    "vae_decoder.onnx",
    input_names=["latent"],
    output_names=["image"],
    opset_version=18,
    do_constant_folding=True,
    dynamic_axes={
        "latent": {0: "batch", 2: "height", 3: "width"},
        "image": {0: "batch", 2: "height", 3: "width"},
    },
    export_params=True,
)

print("ONNX export completed!")
print("生成ファイル: vae_encoder.onnx  /  vae_decoder.onnx")

補足ポイント

解像度について

  • VAEは通常「1/8スケール」で動作します
    • 512×512 → latentは 64×64
  • 任意サイズに対応するには dynamic_axes を指定するのが重要です

encoder / decoder単体利用の注意

分離したONNXは便利ですが:

  • 通常のStable Diffusionパイプラインではそのまま使えない
  • 自前で推論パイプラインを組む必要がある

よくある用途

  • Web(WebGPU / WebAssembly)での推論
  • モバイルアプリ(Android / iOS)
  • 高速推論エンジン(TensorRTなど)への変換

まとめ

  • VAEは「画像 ↔ latent」を変換する重要なモデル
  • safetensorsにはencoderとdecoderが両方含まれている
  • ONNX化する際は分離して書き出すのが一般的
  • PyTorchから簡単にexport可能