dskjal
広告
広告

Pytorch モデルを fp16 で動作させる方法

カテゴリ:deeplearning

pytorch のモデルは half() を呼び出せば fp16 になる。しかし、LayerNormalization などのレイヤーは非常に小さな値を扱うので、fp32 で動作させた方が良い。

LayerNorm を fp32 に戻す方法

model = # モデルのロード
model = model.half()    # fp16 に変換

def cast_ln_to_fp32(m):
    # PyTorchの LayerNorm のインスタンスであれば
    if isinstance(m, torch.nn.LayerNorm):
        m.float() # LayerNorm の重みを FP32 にキャスト

# 特定のレイヤーを fp32 に戻す
model.apply(cast_ln_to_fp32)

'''
その他のコード
'''

with torch.no_grad():
    # AMP の使用は必須ではないが、精度の必要な計算を自動で fp32 で計算するので使用した方が良い
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        model() # 推論

pytorch は元のウェイトの情報を RAM 上に保存している。m.float() が実行されたときにその元の精度のウェイトがロードされるので、half() 実行後に精度を戻しても情報は失われない。


広告
広告

カテゴリ