Stable Diffusion ONNX U-Net に別モデルを統合する方法

Stable Diffusion の ONNX 版 U-Net を改造して、別の ONNX モデルを内部に結合することで、

  • IP-Adapter
  • カスタム Attention
  • Control 信号

などを直接モデル内部に組み込むことができます。

この記事では 既存の model.onnx に別の model.onnx を統合する方法を紹介します。
例として、以下の位置に外部モデルを接続します。

/up_blocks.1/attentions.0/transformer_blocks.0/attn2/to_q/MatMul

そして生成された特徴量を

/up_blocks.1/attentions.0/transformer_blocks.0/attn2/to_out.0/MatMul

へ入力します。


全体構造

結合後の概念図

                 to_q/MatMul
                      │
            MultiHeadAttention
                      │
                 外部モデル
                      │
                      ▼
            to_out.0/MatMul

また既存の Reshape_3_output_0 は使用しないため切断します。


必要なライブラリ

pip install onnx onnxruntime onnx-graphsurgeon
  • onnx:モデル読み込み
  • onnx_graphsurgeon:グラフ編集

汎用的な結合手順

  1. メインモデルを読み込む
  2. 結合するサブモデルを読み込む
  3. グラフをマージ
  4. 特定ノードの出力を接続
  5. 不要ノードを切断
  6. 保存

Pythonコード

import onnx
import onnx_graphsurgeon as gs

MAIN_MODEL = "model.onnx"
SUB_MODEL = "submodel.onnx"
OUTPUT_MODEL = "merged_model.onnx"

# -------------------------
# モデル読み込み
# -------------------------

main_graph = gs.import_onnx(onnx.load(MAIN_MODEL))
sub_graph = gs.import_onnx(onnx.load(SUB_MODEL))

# -------------------------
# ノード検索
# -------------------------

def find_node(graph, name):
    for n in graph.nodes:
        if n.name == name:
            return n
    return None

to_q = find_node(
    main_graph,
    "/up_blocks.1/attentions.0/transformer_blocks.0/attn2/to_q/MatMul"
)

to_out = find_node(
    main_graph,
    "/up_blocks.1/attentions.0/transformer_blocks.0/attn2/to_out.0/MatMul"
)

# -------------------------
# 外部入力追加
# -------------------------

ip_k_3 = gs.Variable("ip_k_3", dtype=None, shape=None)
ip_v_3 = gs.Variable("ip_v_3", dtype=None, shape=None)

main_graph.inputs.append(ip_k_3)
main_graph.inputs.append(ip_v_3)

# -------------------------
# サブモデル入力を接続
# -------------------------

# submodel の最初の入力
sub_input = sub_graph.inputs[0]

# to_q の出力をサブモデルに入力
sub_input.outputs.clear()
sub_input.inputs = [to_q.outputs[0]]

# -------------------------
# サブモデルを統合
# -------------------------

main_graph.nodes.extend(sub_graph.nodes)

# -------------------------
# 出力接続
# -------------------------

sub_output = sub_graph.outputs[0]

to_out.inputs[0] = sub_output

# -------------------------
# Reshape_3 切断
# -------------------------

for node in main_graph.nodes:
    if node.name.endswith("Reshape_3"):
        node.outputs = []

# -------------------------
# グラフ整理
# -------------------------

main_graph.cleanup()
main_graph.toposort()

# -------------------------
# 保存
# -------------------------

onnx.save(gs.export_onnx(main_graph), OUTPUT_MODEL)

print("Merged model saved:", OUTPUT_MODEL)

カスタマイズポイント

接続位置

別の層に接続する場合はノード名を変更します。

  • attn1
  • mid_block
  • down_blocks

外部入力の追加

IP-Adapter などの場合、以下を追加できます。

  • ip_k
  • ip_v
  • image_embedding

融合方法

Concat / Add / Attention など、融合方法も自由に変更できます。

  • Concat
  • Add
  • MatMul
  • CustomAttention

応用例

この方法を使うと以下が実装できます。

  • IP-Adapter統合:image embedding を cross attention に追加
  • ControlNet簡易版:条件特徴を U-Net に直接注入
  • LoRA固定化:LoRA をモデル内部に焼き込み
  • カスタムAttention:Transformer構造を差し替え

注意点

ノード名はモデルごとに異なる

Netron で確認するのがおすすめです。

Tensor shape を合わせる

Concat の axis や MatMul の shape は一致させる必要があります。

cleanup() を必ず実行

未接続ノードが残ると ONNX Runtime でエラーになります。


まとめ

ONNX のグラフ編集を使うと、Stable Diffusion の U-Net に別 ONNX モデル・Attentionモジュール・外部条件などを直接結合できます。

基本パターンは次の5ステップです。

  1. モデル読み込み
  2. ノード取得
  3. グラフ結合
  4. 接続変更
  5. cleanup()

これを応用すればかなり自由に Stable Diffusion を改造できます。