Pytorch モデルを fp16 で動作させる方法
カテゴリ:deeplearning
pytorch のモデルは half() を呼び出せば fp16 になる。しかし、LayerNormalization などのレイヤーは非常に小さな値を扱うので、fp32 で動作させた方が良い。
LayerNorm を fp32 に戻す方法
model = # モデルのロード model = model.half() # fp16 に変換 def cast_ln_to_fp32(m): # PyTorchの LayerNorm のインスタンスであれば if isinstance(m, torch.nn.LayerNorm): m.float() # LayerNorm の重みを FP32 にキャスト # 特定のレイヤーを fp32 に戻す model.apply(cast_ln_to_fp32) ''' その他のコード ''' with torch.no_grad(): # AMP の使用は必須ではないが、精度の必要な計算を自動で fp32 で計算するので使用した方が良い with torch.autocast(device_type="cuda", dtype=torch.float16): model() # 推論
pytorch は元のウェイトの情報を RAM 上に保存している。m.float() が実行されたときにその元の精度のウェイトがロードされるので、half() 実行後に精度を戻しても情報は失われない。