dskjal
広告
広告

Advantage Weighted Matching のアルゴリズム

カテゴリ:deeplearning

Advantage Weighted Matching: Aligning RL with Pretraining in Diffusion Models

学習効率のいい GRPO 系列の強化学習手法。

論文の背景

DDPO はベースモデルの学習時と強化学習の学習時とで尤度の表現方法が違う。DDPO はノイズを使用した暗黙的な desnoising score matching (DSM) を行っており、このノイズは事前学習では使用されておらず、これが DSM target の分散を増加させる原因となっている。

Advantage Weighted Matching (AWM) では reverse-time descritization を使用せず、報酬シグナルを直接フロー/スコアマッチングに単純かつ効率的に組み込む。これの利点は2つある。

  1. ノイジーなデータによる分散増加を防ぎ、収束を早める
  2. サンプラーとノイズレベルの選択が任意。DDPO では Euler-Maruyama descritization が必須だった

論文の発表順に

  1. DDPO
  2. Flow-GRPO
  3. MixGRPO
  4. AWM

アルゴリズムの概要

  1. 学習モデルで複数の画像を生成する
  2. それぞれの画像に対し報酬モデルで画像を評価する
  3. それぞれの画像に対しノイズを付与し学習モデルでデノイズすることで Score Matching Loss を計算
  4. それぞれの画像に対しノイズを付与しベースモデルでデノイズすることで正則化用 KL 損失を計算
  5. 画像評価と Score Matching Loss とを掛け算したもの+ KL 損失とで勾配を計算

疑似コード

# B はバッチサイズとする
for i in range(num_training_steps):
    latents = sampler(model, prompt) # 任意のサンプラーで複数枚画像を生成
    samples = vae.decode(latents) # 出力次元は [B, 3, H, W]
    rewards = reward_fn(samples) # それぞれの画像の報酬を計算。出力次元は [B]
    advantages = cal_adv(rewards, prompt).detach() # 出力次元は [B]

    noise = randn_like(latents)
    timesteps = get_timesteps(latents)
    noisy_latents = fwd_diffusion(latents, noise, timesteps) # ノイズ付与
    velocity_pred = model(noisy_latents, timesteps, prompt)
    velocity_ref = ref_model(noisy_latents, timesteps, prompt) # kl 損失計算に使う

    log_p = -((velocity_pred - (noise-latents))**2).mean() # Flow Matching Loss
    ratio = torch.exp(log_p - log_p.detach()) # or log_p_old for off-policy update
    policy_loss = -advantages * ratio
    kl_loss = weight(timesteps)*((velocity_pred - velocity_ref)**2).mean()
    loss = policy_loss + beta * kl_loss

def cal_adv(rewards: list) -> list:
    advantage = rewards - rewards.mean()

    # 正規化する場合
    # epsilon = 1e-5  # ゼロ除算エラーよけ
    # advantage = advantage / (advantage.std() + epsilon)

    """ [Pref-GRPO](https://www.arxiv.org/abs/2508.20751) 
        正規化すると分散が小さい場合 advantage が大きくなりすぎて報酬ハッキングが発生する
        なので勝率を advantage として利用すると学習が安定する
    """
    # ブロードキャストで総当たり比較を一括で行う
    # 自分自身との比較も含まれるため、合計から -1 している
    # win_rates = (rewards[:, None] >= rewards).sum(axis=1) - 1
    # advantage = (win_rates - win_rates.mean())/win_rates.std()

    return advantage

DDPO

サンプル(ステップ)ごとに以下のデータを保存しておかなければならないのでメモリ負荷が高い。コードも長く複雑になる。

疑似コード
import torch
from torch import nn, optim
from torch.distributions import Normal

# ---------- ユーティリティ(ユーザが実装) ----------
def sampler_generate_trajectory(model, context, sampler_cfg):
    """
    与えたモデルとコンテキストで1つのサンプル軌跡を生成し、
    各時刻のモデル予測 mean, 実際の action (=サンプラーが選んだノイズ補正)
    とその時刻の log_prob を返す疑似関数。

    実装上:
      - sampler_cfg に timesteps, guidance_weight, variance_schedule 等を入れる
      - 各ステップで model.predict_mean(x_t, t, context, cond_scale=guidance_weight) を呼ぶ
      - 実際の action(次状態の x_{t-1} を決めるためのノイズサンプル)を記録
      - 各 step の log_prob を計算(分散はサンプラーに従う固定値または t に依存)
    """
    trajectory = {
        "contexts": context,
        "steps": [],       # 各要素: dict(mean=..., action=..., log_prob=..., t=...)
        "final_image": None,
    }

    x_t = torch.randn(sampler_cfg["image_shape"])  # 初期ノイズ
    for t in sampler_cfg["timesteps"]:
        # model の出力(t 時刻での予測平均 mu_theta)
        mu = model.predict_mean(x_t, t, context, guidance=sampler_cfg.get("guidance", None))
        sigma = sampler_cfg["sigma_fn"](t)  # 分散(スカラーまたはテンソル)
        # サンプリング(実際のサンプルアクション) -- 実際はサンプラーの式に従う
        noise = torch.randn_like(x_t)
        x_next = mu + sigma * noise

        # ガウスの log_prob(action = x_next treated as sampled from N(mu, sigma^2 I))
        # Flatten して計算すると簡単(実装要注意: shape 対応)
        dist = Normal(loc=mu.reshape(-1), scale=(sigma.reshape(-1)))
        log_prob = dist.log_prob(x_next.reshape(-1)).sum()  # スカラー

        trajectory["steps"].append({
            "t": t,
            "mu": mu.detach(),          # old policy 保存時に使いたければ detach して保存
            "action": x_next.detach(),
            "log_prob": log_prob.detach(),  # サンプリング時に記録(old_logprob)
        })

        x_t = x_next

    trajectory["final_image"] = x_t.detach()  # 最終生成画像(x_0 相当)
    return trajectory

def compute_reward(final_image, context):
    """
    ユーザが定義する黒箱報酬関数:
      - aesthetic score, compressibility, VLM scorer, human 評価の surrogate など
    返り値はスカラー(float tensor)
    """
    # 例: return aesthetic_model.score(final_image, context)
    raise NotImplementedError

# ---------- DDPO 学習ループ(高レベル) ----------
def ddpo_train_loop(model: nn.Module,
                    optimizer: optim.Optimizer,
                    contexts_loader,        # (context のミニバッチを返すイテレータ)
                    sampler_cfg,
                    num_iterations=10000,
                    batch_size=256,
                    variant="score_fn",     # "score_fn" (REINFORCE) or "ppo" (importance sampling / PPO)
                    ppo_clip=0.2,
                    reward_normalize=True):
    """
    高レベルの学習ループ
    - contexts_loader は各イテレーションで batch_size 個の context を返す想定
    - 各イテレーション: ミニバッチのコンテキストについて on-policy サンプリングを行い、
      そのサンプル群で勾配更新を行う(DDPO の基本フロー)
    """
    # 任意で履歴を保持して報酬の正規化を行う(論文では文脈ごとに正規化)
    running_stats = {}  # key=context_id -> {mean, var, count} など(簡略化)

    for it in range(num_iterations):
        trajectories = []
        rewards = []
        old_log_probs = []

        # 1) データ収集(on-policy: 現在の model からサンプル)
        for _ in range(batch_size):
            context = next(contexts_loader)
            traj = sampler_generate_trajectory(model, context, sampler_cfg)
            r = compute_reward(traj["final_image"], context)
            # サンプリング時に保存した各 step の log_prob を sum して軌跡の logprob に
            sum_logprob = sum([s["log_prob"] for s in traj["steps"]])
            trajectories.append(traj)
            rewards.append(r.detach() if isinstance(r, torch.Tensor) else torch.tensor(r))
            old_log_probs.append(sum_logprob.detach())

        rewards = torch.stack(rewards)  # shape [batch_size]

        # 2) 報酬正規化(文脈ごとに normalize するのがよい)
        if reward_normalize:
            # 簡便にバッチ単位での正規化(実運用では context ごとに正規化)
            r_mean = rewards.mean()
            r_std = rewards.std(unbiased=False) + 1e-8
            advantages = (rewards - r_mean) / r_std
        else:
            advantages = rewards

        # 3) 勾配推定とパラメータ更新
        if variant == "score_fn":
            # REINFORCE スタイル: loss = - E[ (sum_t log pi_theta(a_t|s_t)) * advantage ]
            # ここでは on-policy データを使って単純に1ステップ更新
            optimizer.zero_grad()
            losses = []
            for traj, A in zip(trajectories, advantages):
                # 再計算: 各 step の log_prob を現在の model を通して計算し直す
                # (ただし、サンプリング時に記録した action を用いて log_prob を再評価)
                sum_logprob_current = 0.0
                for step in traj["steps"]:
                    t = step["t"]
                    action = step["action"]  # sample として選ばれた次状態
                    # 再度モデルで平均を計算(勾配を通す)
                    mu = model.predict_mean(action, t, traj["contexts"], guidance=sampler_cfg.get("guidance", None))
                    sigma = sampler_cfg["sigma_fn"](t)
                    dist = Normal(loc=mu.reshape(-1), scale=sigma.reshape(-1))
                    lp = dist.log_prob(action.reshape(-1)).sum()
                    sum_logprob_current = sum_logprob_current + lp
                # REINFORCE loss (negative for gradient ascent)
                loss = - sum_logprob_current * A
                losses.append(loss)
            loss = torch.stack(losses).mean()
            loss.backward()
            optimizer.step()

        elif variant == "ppo":
            # Importance-sampling / PPO 風
            # store old_log_probs from sampling step (detached) as log pi_old
            optimizer.zero_grad()
            surrogate_terms = []
            for traj, A, logp_old in zip(trajectories, advantages, old_log_probs):
                # compute current sum_logprob
                sum_logprob_current = 0.0
                for step in traj["steps"]:
                    t = step["t"]
                    action = step["action"]
                    mu = model.predict_mean(action, t, traj["contexts"], guidance=sampler_cfg.get("guidance", None))
                    sigma = sampler_cfg["sigma_fn"](t)
                    dist = Normal(loc=mu.reshape(-1), scale=sigma.reshape(-1))
                    lp = dist.log_prob(action.reshape(-1)).sum()
                    sum_logprob_current = sum_logprob_current + lp

                # ratio = exp(logpi_new - logpi_old)
                ratio = torch.exp(sum_logprob_current - logp_old)
                unclipped = ratio * A
                clipped = torch.clamp(ratio, 1.0 - ppo_clip, 1.0 + ppo_clip) * A
                surrogate = - torch.min(unclipped, clipped)  # negative because we minimize loss
                surrogate_terms.append(surrogate)
            loss = torch.stack(surrogate_terms).mean()
            loss.backward()
            optimizer.step()

        else:
            raise ValueError("unknown variant")

        # (任意) ログ出力
        if it % 10 == 0:
            print(f"it={it} mean_reward={rewards.mean().item():.4f}")

    # end training loop

広告
広告

カテゴリ