ONNXモデルを読み込む際に、以下のようなエラーに遭遇したことはありませんか?

Node (MultiHeadAttention_28) Op (MultiHeadAttention) [ShapeInferenceError] 
Inputs 0 (query) shall be 3 or 5 dimensions

この記事では、このエラーの意味と原因、さらに解決に重要な「dynamic_axes」の考え方と書き方について、実践的に解説します。


エラーの意味を分解する

まずはエラーメッセージを分解して理解しましょう。

該当箇所

Inputs 0 (query) shall be 3 or 5 dimensions

意味

  • Inputs 0 (query)
    → MultiHeadAttentionの最初の入力(query)
  • shall be 3 or 5 dimensions
    → 3次元または5次元でなければならない

つまり、

「queryとして渡されたテンソルの次元数が想定と違う」

というエラーです。


MultiHeadAttentionが要求する入力形状

MultiHeadAttentionでは、通常以下のような形状が期待されます。

3次元(一般的)

[batch_size, sequence_length, hidden_size]

例:

[1, 77, 768]

5次元(拡張ケース)

[batch_size, num_heads, seq_length, head_dim, extra_dim]

よくある原因

1. 入力が2次元になっている

例:

[77, 768]

→ batch次元が抜けている

2. reshapeミス

本来3次元にすべきところを、誤って4次元などにしている

3. ONNX変換時の設定ミス

特にここで重要になるのが dynamic_axes です


dynamic_axesとは何か?

ONNXにモデルを変換する際、入力サイズを固定するか、可変にするかを指定できます。

dynamic_axesの役割

「この次元は実行時に変わってもいい」という宣言

例えば:

  • バッチサイズは毎回変わる
  • 文章の長さ(seq_length)も変わる

こういった場合に必要です。


dynamic_axesを使わない場合の問題

dynamic_axesを指定しないと:

  • 入力サイズが完全固定される
  • 想定外のshapeで実行するとエラーになる
  • 今回のようなShapeInferenceErrorの原因になる

dynamic_axesの書き方

基本構文はこちら:

dynamic_axes={
    '入力名': {次元番号: '名前'}
}

実践例(MultiHeadAttention対応)

import torch

dummy_input = torch.randn(1, 77, 768)

torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=['query'],
    output_names=['output'],
    dynamic_axes={
        'query': {
            0: 'batch_size',
            1: 'sequence_length'
        },
        'output': {
            0: 'batch_size',
            1: 'sequence_length'
        }
    },
    opset_version=17
)

各指定の意味

設定 意味
0: 'batch_size' バッチサイズは可変
1: 'sequence_length' 文章長も可変
input_names ONNX側の入力名
opset_version=17 新しい演算子(MultiHeadAttention)対応

なぜdynamic_axesが重要なのか

MultiHeadAttentionは内部で複雑なshape計算を行います。

そのため:

  • shapeが少しでも想定とズレると即エラー
  • ONNX Runtimeは厳密にチェックする

柔軟なshapeを許可するdynamic_axesが重要


トラブルシューティングまとめ

エラーが出た場合は以下を確認:

チェックリスト

  • 入力が3次元になっているか
  • batch次元が抜けていないか
  • ONNX変換時にdynamic_axesを指定しているか
  • opset_versionが新しいか(16以上推奨)

よくある修正例

NG(2次元)

input = torch.randn(77, 768)

OK(3次元)

input = torch.randn(1, 77, 768)

まとめ

今回のエラーの本質はシンプルです。

MultiHeadAttentionに渡すqueryの次元が間違っている

そして、その背景には:

  • ONNX変換時の設定不足
  • dynamic_axes未指定

といった問題があります。


最後に

ONNXは「一度変換すればどこでも動く」強力な仕組みですが、その分shapeの厳密さが求められます。

特にTransformer系モデルでは:

  • 次元(dimension)
  • dynamic_axes
  • opset_version

この3つを意識することで、多くのエラーは回避できます。