dskjal
広告
広告

ComfyUI の CFGNorm・Adaptive Projected Guidanceノードとは何か

カテゴリ:deeplearning

Adaptive Projected Guidance

CFG をデノイズ画像と直交する方向と平行方向とに分解したとき、平行方向ベクトルが彩度を高める。なので直交方向のみスケールを効かせることで高い CFG での高彩度化を抑えるのが Adaptive Projected Guidance

解説記事は過飽和を防ぐためにCFGの代わりにAPGAPG guidanceノードで高CFG値で生成する!@ComfyUI x Chromaモデルを参照。

ソースは confy_extras/nodes_apg.py にある。

CFGNorm

STIV: Scalable Text and Image Conditioned Video Generation の Appendix の A. Joint Image-Text Classifier-free Guidance の CFG-Renormalization が元と考えられる。

以下の式のハットありの F は CFG 適用後のデノイズ結果。

\[ \begin{split} \hat{\mathbf{\mathit{F}}}_\theta(\mathbf{\mathit{x}}_t, \mathbf{\mathit{c}}_T, \mathbf{\mathit{c}}_I, t) &=\mathbf{\mathit{F}}_\theta(\mathbf{\mathit{x}}_t, \varnothing, \varnothing, t) + \omega \cdot (\mathbf{\mathit{F}}_\theta(\mathbf{\mathit{x}}_t, \mathbf{\mathit{c}}_T, \mathbf{\mathit{c}}_I, t) - \mathbf{\mathit{F}}_\theta(\mathbf{\mathit{x}}_t, \varnothing, \varnothing, t))\\ \tilde{\mathbf{\mathit{F}}_\theta}(\mathbf{\mathit{x}}_t, \mathbf{\mathit{c}}_T, \mathbf{\mathit{c}}_I, t) &= ||\mathbf{\mathit{F}}_\theta(\mathbf{\mathit{x}}_t, \mathbf{\mathit{c}}_T, \mathbf{\mathit{c}}_I, t)||\dfrac{\hat{\mathbf{\mathit{F}}_\theta}(\mathbf{\mathit{x}}_t, \mathbf{\mathit{c}}_T, \mathbf{\mathit{c}}_I, t)}{||\hat{\mathbf{\mathit{F}}_\theta}(\mathbf{\mathit{x}}_t, \mathbf{\mathit{c}}_T, \mathbf{\mathit{c}}_I, t)||} \end{split} \]

ComfyUI の CFGNorm ノードは、単純な正規化で CFG の効きすぎを抑制するノード。

# comy_extras/nodes_cfy.py
class CFGNorm(io.ComfyNode):
    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="CFGNorm",
            category="advanced/guidance",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01),
            ],
            outputs=[io.Model.Output(display_name="patched_model")],
            is_experimental=True,
        )

    @classmethod
    def execute(cls, model, strength) -> io.NodeOutput:
        m = model.clone()
        def cfg_norm(args):
            cond_p = args['cond_denoised']  # プロンプトを使ったデノイズ結果
            pred_text_ = args["denoised"]   # CFG 適用後のデノイズ結果

            # SDXL で解像度が 512 でバッチが 1 の時 cond_p.shape = pred_text_.shape = (1, 4, 64, 64)
            norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True)    # チャンネル単位正規化
            norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True)    # チャンネル単位正規化
            scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0)
            return pred_text_ * scale * strength

        m.set_model_sampler_post_cfg_function(cfg_norm)
        return io.NodeOutput(m)

広告
広告

カテゴリ