VAE(Variational Autoencoder)とは
VAE(Variational Autoencoder)は、画像を「圧縮」と「復元」するためのニューラルネットワークです。
Stable Diffusion では、画像生成の効率を高めるために重要な役割を担っています。
VAEは主に次の2つの構成要素から成り立っています:
- Encoder(エンコーダー)
画像(RGB)を潜在表現(latent)に圧縮する - Decoder(デコーダー)
潜在表現(latent)を画像に復元する
Stable DiffusionにおけるVAEの役割
Stable Diffusionでは、直接画像を生成するのではなく、次のような流れになっています:
- ノイズから潜在空間(latent)を生成
- そのlatentをVAEのdecoderで画像に変換
また、画像を入力として使う場合(img2imgなど)は:
- 画像をVAEのencoderでlatentに変換
- latentを加工・生成
- decoderで画像に戻す
つまり:
| 処理 | 使用する部分 |
|---|---|
| 画像 → latent | Encoder |
| latent → 画像 | Decoder |
safetensorsのVAEファイルについて
wan_2.1_vae.safetensors のようなファイルには、通常:
- encoder の重み
- decoder の重み
が1つのファイルにまとめて保存されています。
内部的には以下のようなキーで分かれています:
encoder.*decoder.*
ONNX化のために分離する理由
ONNXに変換する際は、以下のように分けて使うことが多いです:
vae_encoder.onnx(画像 → latent)vae_decoder.onnx(latent → 画像)
理由:
- 推論用途では decoder だけ使うケースが多い
- 軽量化・最適化がしやすい
- モバイルやWeb環境で扱いやすい
ONNXに変換する手順
以下は、wan_2.1_vae.safetensors を
encoder / decoder に分けて ONNX に書き出すサンプルコードです。
必要ライブラリ
pip install torch diffusers safetensors onnx
サンプルコード(完全版)
import torch
from safetensors.torch import load_file
from diffusers import AutoencoderKL
import warnings
warnings.filterwarnings("ignore")
# =========================
# 1. VAEのロード(修正ポイント)
# =========================
vae_path = "wan_2.1_vae.safetensors"
# from_single_file が最も安定
vae = AutoencoderKL.from_single_file(vae_path)
vae.eval()
# VAEのスケーリングファクターを確認(通常 SD系は 0.18215 前後)
print(f"VAE scaling_factor: {vae.config.scaling_factor}")
# =========================
# 2. Encoder / Decoder の準備
# =========================
# Encoder部分(画像 → latent distribution)
encoder = vae.encoder
quant_conv = vae.quant_conv # posteriorを4chに変換するconv
# Decoder部分(latent → 画像)
decoder = vae.decoder
post_quant_conv = vae.post_quant_conv # latentをDecoderに入力する前のconv
# =========================
# 3. ダミー入力作成
# =========================
# 例: 512x512画像(任意サイズでdynamic対応)
dummy_image = torch.randn(1, 3, 512, 512) # 値域は後で -1~1 に正規化
# latentは通常 1/8 スケール + 4チャネル
dummy_latent = torch.randn(1, 4, 64, 64)
# =========================
# 4. Encoder ONNX エクスポート(画像 → latent mean)
# =========================
class VAEEncoder(torch.nn.Module):
def __init__(self, encoder, quant_conv):
super().__init__()
self.encoder = encoder
self.quant_conv = quant_conv
def forward(self, x):
# Stable Diffusionの標準的な前処理: [-1, 1] に正規化
x = x * 2.0 - 1.0
h = self.encoder(x)
moments = self.quant_conv(h) # shape: (B, 8, h/8, w/8)
mean, _ = moments.chunk(2, dim=1) # meanのみ使用(4ch)
return mean
vae_encoder = VAEEncoder(encoder, quant_conv).eval()
torch.onnx.export(
vae_encoder,
dummy_image,
"vae_encoder.onnx",
input_names=["image"],
output_names=["latent"], # meanのみ出力
opset_version=18,
do_constant_folding=True,
dynamic_axes={
"image": {0: "batch", 2: "height", 3: "width"},
"latent": {0: "batch", 2: "height", 3: "width"},
},
export_params=True,
)
# =========================
# 5. Decoder ONNX エクスポート(latent → 画像)
# =========================
class VAEDecoder(torch.nn.Module):
def __init__(self, post_quant_conv, decoder):
super().__init__()
self.post_quant_conv = post_quant_conv
self.decoder = decoder
def forward(self, latent):
# latentをスケーリング(推論時は通常 1/scaling_factor で補正)
latent = latent / vae.config.scaling_factor
z = self.post_quant_conv(latent)
image = self.decoder(z)
return image
vae_decoder = VAEDecoder(post_quant_conv, decoder).eval()
torch.onnx.export(
vae_decoder,
dummy_latent,
"vae_decoder.onnx",
input_names=["latent"],
output_names=["image"],
opset_version=18,
do_constant_folding=True,
dynamic_axes={
"latent": {0: "batch", 2: "height", 3: "width"},
"image": {0: "batch", 2: "height", 3: "width"},
},
export_params=True,
)
print("ONNX export completed!")
print("生成ファイル: vae_encoder.onnx / vae_decoder.onnx")
補足ポイント
解像度について
- VAEは通常「1/8スケール」で動作します
- 512×512 → latentは 64×64
- 任意サイズに対応するには
dynamic_axesを指定するのが重要です
encoder / decoder単体利用の注意
分離したONNXは便利ですが:
- 通常のStable Diffusionパイプラインではそのまま使えない
- 自前で推論パイプラインを組む必要がある
よくある用途
- Web(WebGPU / WebAssembly)での推論
- モバイルアプリ(Android / iOS)
- 高速推論エンジン(TensorRTなど)への変換
まとめ
- VAEは「画像 ↔ latent」を変換する重要なモデル
- safetensorsにはencoderとdecoderが両方含まれている
- ONNX化する際は分離して書き出すのが一般的
- PyTorchから簡単にexport可能


