ComfyUI の CFGNorm・Adaptive Projected Guidanceノードとは何か
カテゴリ:deeplearning
Adaptive Projected Guidance
CFG をデノイズ画像と直交する方向と平行方向とに分解したとき、平行方向ベクトルが彩度を高める。なので直交方向のみスケールを効かせることで高い CFG での高彩度化を抑えるのが Adaptive Projected Guidance。
解説記事は過飽和を防ぐためにCFGの代わりにAPGやAPG guidanceノードで高CFG値で生成する!@ComfyUI x Chromaモデルを参照。
ソースは confy_extras/nodes_apg.py にある。
CFGNorm
STIV: Scalable Text and Image Conditioned Video Generation の Appendix の A. Joint Image-Text Classifier-free Guidance の CFG-Renormalization が元と考えられる。
以下の式のハットありの F は CFG 適用後のデノイズ結果。
\[ \begin{split} \hat{\mathbf{\mathit{F}}}_\theta(\mathbf{\mathit{x}}_t, \mathbf{\mathit{c}}_T, \mathbf{\mathit{c}}_I, t) &=\mathbf{\mathit{F}}_\theta(\mathbf{\mathit{x}}_t, \varnothing, \varnothing, t) + \omega \cdot (\mathbf{\mathit{F}}_\theta(\mathbf{\mathit{x}}_t, \mathbf{\mathit{c}}_T, \mathbf{\mathit{c}}_I, t) - \mathbf{\mathit{F}}_\theta(\mathbf{\mathit{x}}_t, \varnothing, \varnothing, t))\\ \tilde{\mathbf{\mathit{F}}_\theta}(\mathbf{\mathit{x}}_t, \mathbf{\mathit{c}}_T, \mathbf{\mathit{c}}_I, t) &= ||\mathbf{\mathit{F}}_\theta(\mathbf{\mathit{x}}_t, \mathbf{\mathit{c}}_T, \mathbf{\mathit{c}}_I, t)||\dfrac{\hat{\mathbf{\mathit{F}}_\theta}(\mathbf{\mathit{x}}_t, \mathbf{\mathit{c}}_T, \mathbf{\mathit{c}}_I, t)}{||\hat{\mathbf{\mathit{F}}_\theta}(\mathbf{\mathit{x}}_t, \mathbf{\mathit{c}}_T, \mathbf{\mathit{c}}_I, t)||} \end{split} \]ComfyUI の CFGNorm ノードは、単純な正規化で CFG の効きすぎを抑制するノード。
# comy_extras/nodes_cfy.py
class CFGNorm(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="CFGNorm",
category="advanced/guidance",
inputs=[
io.Model.Input("model"),
io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01),
],
outputs=[io.Model.Output(display_name="patched_model")],
is_experimental=True,
)
@classmethod
def execute(cls, model, strength) -> io.NodeOutput:
m = model.clone()
def cfg_norm(args):
cond_p = args['cond_denoised'] # プロンプトを使ったデノイズ結果
pred_text_ = args["denoised"] # CFG 適用後のデノイズ結果
# SDXL で解像度が 512 でバッチが 1 の時 cond_p.shape = pred_text_.shape = (1, 4, 64, 64)
norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True) # チャンネル単位正規化
norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True) # チャンネル単位正規化
scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0)
return pred_text_ * scale * strength
m.set_model_sampler_post_cfg_function(cfg_norm)
return io.NodeOutput(m)