DDPO(Diffusion-Direct Preference Optimization)の学習方法
Diffusion-DPO Diffusion Model Alignment Using Direct Preference Optimization は SD3 でも使われた強化学習手法。
ただし、画風や新しい概念の学習のような一般的なタスクは SFT(教師ありファインチューニング)が適している。「○○がうまく描けない」というニーズは○○が言語化できているので SFT を使うべき。SFT はデータセットを用意しやすいし学習負荷も低い。
DDPO を使うケース
- 言語化が困難だが描いてほしくないもの(暴力的・性的表現)がある
- ユーザーの選好データを持っている
SFT(教師ありファインチューニング)と DDPO との違い
| 手法 | 必要なデータ | 学習1回に必要な 推論回数 |
|---|---|---|
| SFT | 教師画像 キャプション | 1回 |
| DDPO | 教師画像 好ましくない画像 キャプション 参照モデル | 4回 |
参照モデルは強化学習前のベースモデル。
DDPO は単純にモデルを2つロードしなければならず、学習1回に必要な推論回数は4回(学習モデルと参照モデルとがそれぞれ教師画像と好ましくない画像とに対しデノイズを行う)。
好ましくない画像の作成
通常はユーザー投票で画像評価を集める。
単に画質を向上したいだけなら、学習したモデルを使って、学習に使ったキャプションで生成してみて、上手く生成できなかった画像を使う方法もある。
損失関数
\[ \mathcal{L}_{DPO} = - \mathbb{E}_{(\mathbf{x}_w, \mathbf{x}_l)\sim \mathcal{D}}\left [\mathrm{log} \sigma \left ( \beta \mathrm{log}\dfrac{P_{\pi_{\theta}}(\mathbf{x}_w|\mathbf{c})}{P_{\pi_{ref}}(\mathbf{x}_w|\mathbf{c})} - \beta \mathrm{log}\dfrac{P_{\pi_{\theta}}(\mathbf{x}_l|\mathbf{c})}{P_{\pi_{ref}}(\mathbf{x}_l|\mathbf{c})} \right ) \right ] \]ややこしいが以下の手順で計算する。
- theta_w_loss = 好ましい画像・キャプション・学習させるモデルを使って計算された損失
- theta_l_loss = 好ましくない画像・キャプション・学習させるモデルを使って計算された損失
- ref_w_loss = 好ましい画像・キャプション・参照モデル(強化学習前のベースモデル)を使って計算された損失
- ref_l_loss = 好ましくない画像・キャプション・参照モデル(強化学習前のベースモデル)を使って計算された損失
reward_diff = (ref_w_loss - theta_w_loss) - (ref_l_loss - theta_l_loss)
dpo_loss = -F.logsigmoid(beta * reward_diff) # beta は選好の強度。固定するか、徐々に上げていく
final_loss = dpo_loss.mean()
上記のコードは数式と符号が一致しないように見えるが、Diffusion-DPOでは 「損失(loss)」ではなく「対数尤度 log-likelihood」ベースで式が定義されており、それを loss に変換する場合は符号処理を注意する必要がある。
損失は小さい方がいいが、対数尤度は大きい方がいいので符号は逆になる。
学習について
学習に使うタイムステップをすべて事前に計算しておいた場合、ref_loss をキャッシュしておける。そうすると参照モデルを VRAM にロードする必要がないので VRAM の節約になる。
参照モデルの loss の計算は独立しているので、クラウドで並列計算すれば時間も節約できる。
そうした場合、モデルの学習に必要なのは教師画像2回の推論なので、ローカルでの実施も現実的。同じ GPU が2台あるなら、theta_w_loss と theta_l_loss を同時に計算することで計算時間を倍速化できる。
サンプルコード
import torch
import torch.nn.functional as F
import random
# --- 変数定義 ---
# model : 学習させる拡散モデル (U-Net) - パラメータθを持つ
# ref_model : 固定された参照モデル (U-Net) - パラメータθ_refを持つ
# x_w : 好ましい画像の潜在表現 (latent)
# x_l : 好ましくない画像の潜在表現 (latent)
# caption : テキストエンコーダーでエンコードされたプロンプトの埋め込み
# beta : DPOのハイパーパラメータ (選好の強度)
def calculate_diffusion_dpo_loss(model, ref_model, x_w, x_l, caption, beta=0.1):
"""
Diffusion-DPOの損失を計算する関数
"""
# 1. ノイズと時間ステップのサンプリング
# 各画像に対して独立した時間ステップとノイズをサンプリングすることが一般的だが、
# シンプルにするため、同じバッチ内の全ペアに単一のタイムステップ t を使用する
batch_size = x_w.shape[0]
# タイムステップ t をランダムにサンプリング
timesteps = torch.randint(0, model.scheduler.config.num_train_timesteps, (batch_size,), device=x_w.device).long()
# 標準正規分布からノイズをサンプリング
noise = torch.randn_like(x_w)
# 2. ノイズを加えた潜在表現 (Noised Latents) の計算
# スケジューラを使用して、潜在表現にノイズを加える
# x_t = scheduler.add_noise(original_samples=x, noise=noise, timesteps=t) に相当
# ここでは簡略化し、ノイズ加算を外部で行うか、U-Netの入力としてノイズ付きの潜在表現を用いると仮定
x_w_noisy = model.scheduler.add_noise(x_w, noise, timesteps)
x_l_noisy = model.scheduler.add_noise(x_l, noise, timesteps)
# 3. ノイズ予測 (U-Netの出力)
# 現在のモデル(policy model)によるノイズ予測
# 拡散モデルはノイズ予測器として機能する: 𝛜_θ(x_t, t, c)
pred_noise_w = model(x_w_noisy, timesteps, caption).sample
pred_noise_l = model(x_l_noisy, timesteps, caption).sample
# 参照モデルによるノイズ予測(参照モデルは勾配計算から外す)
with torch.no_grad():
pred_noise_ref_w = ref_model(x_w_noisy, timesteps, caption).sample
pred_noise_ref_l = ref_model(x_l_noisy, timesteps, caption).sample
# 4. 報酬の計算 (Reward Proxy)
# DPOでは、ノイズ予測誤差の L2 ノルムを用いて、対数尤度比の代理項を計算する。
# Log-Likelihood Ratio (LLR) の代理項: LLR_proxy = -(L2_Loss_Policy - L2_Loss_Ref) / 2 * sigma^2
# 損失関数自体は、以下の Log-Probability Ratio (LPR) に基づく。
# ノイズ予測誤差の二乗ノルム (L2 Loss)
# L2 = || noise - pred_noise ||^2
l2_w = F.mse_loss(noise, pred_noise_w, reduction='none').mean([1, 2, 3]) # 平均化して (Batch,) サイズにする
l2_l = F.mse_loss(noise, pred_noise_l, reduction='none').mean([1, 2, 3])
l2_ref_w = F.mse_loss(noise, pred_noise_ref_w, reduction='none').mean([1, 2, 3])
l2_ref_l = F.mse_loss(noise, pred_noise_ref_l, reduction='none').mean([1, 2, 3])
# 5. 対数確率比 (Log-Probability Ratio, LPR) の計算
# LPRは、政策モデルと参照モデルのノイズ予測誤差の差に基づいて計算される
# LPR_w = -1/(2*sigma^2) * (l2_w - l2_ref_w)
# 拡散スケジューラの分散 term (sigma^2) は、一般的にタイムステップ t に依存する。
# 簡易化のため、ここでは分散項をベータに組み込む、あるいは無視して差分のみを扱う (手法による)
# 論文に従い、LPRの差分を計算する
# LPR_theta(x) = L2_Loss_Ref(x) - L2_Loss_Policy(x)
# LPR_w / LPR_l は、好ましい画像と好ましくない画像に対する報酬項の差分
# 報酬代理 R(x) = -1 * L2_Loss(x) とみなすと、
# R_theta(x) - R_ref(x) = -(L2_Loss_theta(x) - L2_Loss_ref(x))
# DPOの損失の核となるのは、好ましい方と好ましくない方の報酬の差分
reward_diff = (l2_ref_w - l2_w) - (l2_ref_l - l2_l)
# 6. DPO損失の適用
# L_DPO = -log(sigmoid(beta * reward_diff))
dpo_loss = -F.logsigmoid(beta * reward_diff)
# 最終的な損失はバッチ全体での平均
final_loss = dpo_loss.mean()
return final_loss