Pytorch モデルを fp16 で動作させる方法
カテゴリ:deeplearning
pytorch のモデルは half() を呼び出せば fp16 になる。しかし、LayerNormalization などのレイヤーは非常に小さな値を扱うので、fp32 で動作させた方が良い。
LayerNorm を fp32 に戻す方法
model = # モデルのロード
model = model.half() # fp16 に変換
def cast_ln_to_fp32(m):
# PyTorchの LayerNorm のインスタンスであれば
if isinstance(m, torch.nn.LayerNorm):
m.float() # LayerNorm の重みを FP32 にキャスト
# 特定のレイヤーを fp32 に戻す
model.apply(cast_ln_to_fp32)
'''
その他のコード
'''
with torch.no_grad():
# AMP の使用は必須ではないが、精度の必要な計算を自動で fp32 で計算するので使用した方が良い
with torch.autocast(device_type="cuda", dtype=torch.float16):
model() # 推論
pytorch は元のウェイトの情報を RAM 上に保存している。m.float() が実行されたときにその元の精度のウェイトがロードされるので、half() 実行後に精度を戻しても情報は失われない。