微分可能因果探索とNOTEARSの理論

はじめに

因果関係の探索、特に有向非巡回グラフ(DAG: Directed Acyclic Graph)の推定は、機械学習統計学における重要な課題です。DAGは変数間の因果関係を矢印で表現するグラフであり、Bayesianネットワークとも呼ばれています。従来、DAGの学習は組合せ最適化問題として定式化され、非巡回性(acyclicity)の制約を満たすグラフ構造を探索するという非常に難しい問題でした。
なぜ難しいのでしょうか?それは、変数の数が増えるにつれてDAGの可能な構造の数が超指数関数的に増加するためです。例えば、10個の変数だけでも可能なDAG構造は約1018個あります。また、グラフの非巡回性という制約は組合せ的な性質を持ち、連続最適化の枠組みで扱うことが困難でした。
しかし、2018年にZhengらによって提案されたNOTEARS(Non-combinatorial Optimization via Trace Exponential and Augmented lagRangian for Structure learning)は、この難問に革命的なアプローチをもたらしました。NOTEARSの最大の革新点は、離散的な非巡回性制約を連続的な制約に変換したことにあります。これにより、DAG学習問題を標準的な連続最適化問題として解けるようになりました。
本記事では、微分可能因果探索の基本的な考え方から始め、NOTEARSアルゴリズムの理論と実装について詳しく解説します。さらに、NOTEARSから派生した最新の研究動向についても紹介します。

因果探索の伝統的アプローチとその限界

従来の因果グラフ探索は、以下のような離散的な最適化問題として定式化されていました:

\min_{G} Q(G) \quad \text{subject to} \quad G \in \mathcal{D}
ここで、Gはグラフ構造、Q(G)はスコア関数(BIC、MDLなど)、\mathcal{D}はDAGの集合を表します。この問題は複数の理由から解くのが非常に困難です。まず、DAGの空間が離散的で組合せ的であること、次に可能なDAG構造の数が変数の数に対して超指数関数的に増加すること、そして非巡回性制約が扱いにくいことが挙げられます。

これらの課題に対処するため、様々なアルゴリズムが開発されてきました。順序ベースの手法では変数の順序付けを探索し、その順序に基づいてDAGを構築します。貪欲探索法では局所的な改善を繰り返し行います(例:Greedy Equivalence Search)。またスコアベースの手法ではスコア関数を最適化するグラフ構造を探索します。
これらの手法は一定の成功を収めていますが、高次元データや複雑なグラフ構造では依然として計算効率や精度の面で課題がありました。

連続最適化としての因果探索:発想の転換

NOTEARSの革新的なアイデアは、離散的なDAG探索問題を連続最適化問題に変換することです。この転換の鍵となるのは、グラフの隣接行列とその性質です。

構造方程式モデル(SEM)と隣接行列

NOTEARSは線形構造方程式モデル(SEM)を前提としています:

X_j = \sum_{i=1}^{d} W_{ij} X_i + \varepsilon_j
ここで、X = (X_1, \ldots, X_d)は観測変数のベクトル、W \in \mathbb{R}^{d \times d}は重み行列(隣接行列)、\varepsilon_jはノイズ項です。W_{ij} \neq 0は変数iから変数jへの直接の因果関係があることを示します。
この重み行列WがDAGを表現するためには、非巡回性の条件を満たす必要があります。従来の手法では、この非巡回性を保証するために、グラフに新しいエッジを追加するたびに巡回がないかを明示的にチェックする必要がありました。これは計算効率の面で大きなボトルネックとなっていました。

問題の再定式化

NOTEARSは、DAG学習問題を以下のように再定式化します:

\min_{W \in \mathbb{R}^{d \times d}} F(W) \quad \text{subject to} \quad G(W) \in \mathcal{D}
ここで、F(W)はデータに対する適合度とスパース性を評価するスコア関数、G(W)Wから導かれるグラフ、\mathcal{D}はDAGの集合です。典型的には、F(W)は二乗誤差損失と正則化項の組み合わせになります:
F(W) = \frac{1}{2n} \|X - XW\|_F^2 + \lambda \|W\|_1
ここで、X \in \mathbb{R}^{n \times d}n個のサンプルからなるデータ行列、\|\cdot\|_Fはフロベニウスノルム、\|\cdot\|_1L_1ノルム(スパース性を促進)、\lambda正則化パラメータです。
しかし、制約G(W) \in \mathcal{D}は依然として離散的で扱いにくいものです。NOTEARSの最大の貢献は、この離散的な制約を連続的な等式制約に置き換えたことにあります。

NOTEARSの理論:非巡回性の微分可能な特性付け

NOTEARSの中核となるアイデアは、グラフの非巡回性を表現する滑らかな関数h(W)を構築することです。理想的には、この関数は以下の性質を持つべきです。まず、h(W) = 0当且つ当にWが非巡回的(DAG)であること、次にh(W)の値がグラフの「DAG性」を定量化すること、さらにh(W)が滑らか(微分可能)であること、そしてh(W)とその導関数が計算容易であることが挙げられます。

Zhengらは、以下の関数がこれらの条件を満たすことを示しました:

h(W) = \text{tr}(e^{W \circ W}) - d
ここで、\circアダマール積(要素ごとの積)、e^{W \circ W}は行列W \circ Wの行列指数関数、\text{tr}はトレース(対角和)、dは変数の数です。

この関数が非巡回性をどのように特徴付けるのか、直観的に理解してみましょう。

非巡回性の連続的表現:直観的理解

行列S = W \circ Wを考えます。この行列はWのスパース構造を保ちながら、全ての要素を非負にします。任意の正の整数kに対して、S^kの要素(S^k)_{ij}は、ノードiからノードjへのすべての長さkのパスの重みの積の和を表します。
グラフが非巡回的であるとき、ある長さ以上のパスはすべて存在しなくなります。具体的には、d個のノードを持つDAGでは、長さd以上のパスは存在できません。これは、行列S^kk \geq d)のすべての対角要素が0になることを意味します。

行列指数関数の定義から:

e^S = I + S + \frac{1}{2!}S^2 + \frac{1}{3!}S^3 + \cdots

このトレースをとると:

\text{tr}(e^S) = \text{tr}(I) + \text{tr}(S) + \frac{1}{2!}\text{tr}(S^2) + \cdots = d + \text{tr}(S) + \frac{1}{2!}\text{tr}(S^2) + \cdots
ここで、\text{tr}(I) = dです。さらに、\text{tr}(S^k)は長さkのすべての閉路(サイクル)の重みの和と解釈できます。グラフが非巡回的であれば、すべてのk \geq 1に対して\text{tr}(S^k) = 0となります。従って、\text{tr}(e^S) = dとなり、h(W) = \text{tr}(e^S) - d = 0が成立します。
逆に、グラフにサイクルが存在する場合、少なくとも一つのkに対して\text{tr}(S^k) > 0となり、h(W) > 0となります。

この特性付けにより、離散的なDAG制約を連続的な等式制約[tex(W) = 0]に置き換えることができます。この制約は微分可能であり、その勾配も簡単に計算できます:

\nabla h(W) = (e^{W \circ W})^T \circ (2W)

これにより、問題は以下の等式制約付き連続最適化問題に変換されます:

\min_{W \in \mathbb{R}^{d \times d}} F(W) \quad \text{subject to} \quad h(W) = 0

この問題は、拡張ラグランジュ法や罰則法などの標準的な最適化手法で解くことができます。

NOTEARSアルゴリズム:実装と最適化

NOTEARSアルゴリズムは、上記の等式制約付き最適化問題を拡張ラグランジュ法を用いて解きます。拡張ラグランジュ関数は以下のように定義されます:

L_\rho(W, \alpha) = F(W) + \alpha h(W) + \frac{\rho}{2}h(W)^2
ここで、\alphaラグランジュ乗数、\rho > 0はペナルティパラメータです。
アルゴリズムは以下の手順で進行します。まず初期値W_0ラグランジュ乗数\alpha_0を設定します。次に各反復tで、未制約の部分問題W_{t+1} = \arg\min_W L_\rho(W, \alpha_t)を解き、ラグランジュ乗数を更新します:\alpha_{t+1} = \alpha_t + \rho h(W_{t+1})。そして収束判定としてh(W_{t+1}) < \epsilonを満たせば終了します。最後に得られたW閾値処理して最終的なDAG構造を得ます。
部分問題の解法には、L-BFGS(Limited-memory Broyden–Fletcher–Goldfarb–Shanno)アルゴリズムなどの準ニュートン法が効果的です。L_1正則化項がある場合は、近接勾配法や座標降下法などの複合最適化手法が適用できます。

閾値処理ステップは、推定された行列の小さな値(ノイズに起因する可能性がある)をゼロにすることで、より明確なグラフ構造を得るために重要です。

アルゴリズムの実装詳細

NOTEARSの実装は非常にシンプルで、標準的な数値最適化ライブラリを使用して約50行のPythonコードで実現できます。以下はアルゴリズムの核心部分を表すコードの概要です。

def notears(X, lambda1, loss_type='l2'):
    """
    NOTEARSアルゴリズムの実装
    
    Parameters
    ----------
    X : numpy.ndarray
        n×dのサンプルデータ行列
    lambda1 : float
        L1正則化パラメータ
    loss_type : str
        損失関数のタイプ ('l2' または 'logistic')
        
    Returns
    -------
    W_est : numpy.ndarray
        推定された重み行列(隣接行列)
    """
    n, d = X.shape
    # 最適化問題の初期値
    W_est = np.zeros((d, d))
    
    # 目的関数の定義
    def _loss(W):
        if loss_type == 'l2':
            return 0.5 / n * np.square(np.linalg.norm(X - X @ W, 'fro'))
        elif loss_type == 'logistic':
            # ロジスティック回帰損失の実装(省略)
            pass
    
    # 非巡回性制約関数の定義
    def _h(W):
        M = W * W  # アダマール積
        E = scipy.linalg.expm(M)  # 行列指数関数
        h = np.trace(E) - d
        return h
    
    # 非巡回性制約の勾配
    def _grad_h(W):
        M = W * W
        E = scipy.linalg.expm(M)
        return E.T * (2 * W)
    
    # 拡張ラグランジュ法のパラメータ
    alpha = 0
    rho = 1.0
    h_tol = 1e-8
    rho_max = 1e+16
    max_iter = 20
    
    # 拡張ラグランジュ法の実行
    for _ in range(max_iter):
        # 部分問題の解法
        def _objective(w):
            W = w.reshape((d, d))
            loss = _loss(W) + lambda1 * np.sum(np.abs(W))
            h_val = _h(W)
            obj = loss + alpha * h_val + 0.5 * rho * h_val * h_val
            return obj
        
        # 部分問題を解く(例:L-BFGS)
        w_new = scipy.optimize.minimize(
            _objective, W_est.flatten(), 
            method='L-BFGS-B').x
        W_new = w_new.reshape((d, d))
        
        # DAG制約の評価
        h_new = _h(W_new)
        
        # 収束判定
        if h_new <= h_tol:
            W_est = W_new
            break
        
        # ラグランジュ乗数の更新
        alpha = alpha + rho * h_new
        rho = min(rho_max, 10 * rho)
        W_est = W_new
    
    # 小さな値を閾値処理
    W_est[np.abs(W_est) < 0.3] = 0
    
    return W_est
このアルゴリズムの美しさは、標準的な数値最適化ライブラリを利用して、複雑なDAG学習問題を解くことができる点にあります。特に、非巡回性制約の行列指数関数表現が、複雑な組合せ条件を滑らかな関数に変換することで、標準的な勾配ベースの最適化手法を適用可能にしています。

NOTEARSの拡張:非線形モデルへの適用

元のNOTEARSは線形SEMを前提としていますが、この考え方は非線形モデルにも拡張できます。2019年にZhengらによって提案されたNonlinear NOTEARSは、多層パーセプトロンMLP)などの非線形関数を用いて変数間の関係をモデル化します。

X_j = f_j(X_{\text{pa}(j)}) + \varepsilon_j
ここで、f_j非線形関数、X_{\text{pa}(j)}X_jの親変数の集合です。非線形モデルでも、ヤコビアン行列を用いてh(W)と同様の制約を構築することができます。

NOTEARSの理論的保証と実践的性能

NOTEARSアルゴリズムには、いくつかの重要な理論的性質があります。まず、(W) = 0という制約が非巡回性を正確に特徴付けることが証明されています。さらに、適切な正則化の下で、NOTEARSは高次元設定での一致性(consistency)を持ちます。

実践的には、NOTEARSは従来のGreedy Equivalence Search(GES)などの手法と比較して、特に変数の数が多い場合や、グラフの次数(各ノードの接続数)が大きい場合に優れた性能を示します。特に、スケールフリーグラフなどの複雑な構造を持つグラフの学習において、NOTEARSの優位性が顕著です。

最新の発展:NOTEARS以降の微分可能因果探索

NOTEARSの成功以降、微分可能因果探索の分野は急速に発展しています。以下に最新の研究動向をいくつか紹介します。

DAGMA: DAG Structure Learning via Matrix Exponential

DAGMAは、NOTEARSをさらに効率化したアルゴリズムで、2021年にYuらによって提案されました。DAGMAは行列指数関数ベースの勾配降下法を用いて、より高速かつ正確にDAGを学習します。特に大規模グラフの学習において計算効率が大幅に向上しています。

GOLEM: Greedy Optimization of the Evidence Lower Bound

GOLEMは、ベイズ的アプローチを採用したDAG学習法で、2020年にNgらによって提案されました。変分推論を用いて証拠下限(ELBO)を最適化することで、不確実性を考慮したDAG学習を可能にします。

DYNOTEARS: ダイナミックネットワーク学習への拡張

DYNOTEARSは、NOTEARSを時系列データに拡張したもので、2020年にPamfilらによって提案されました。時間的依存関係を考慮したDAG学習を可能にします。

X^{(t)}_j = \sum_{i=1}^{d} W^{(0)}_{ij} X^{(t)}_i + \sum_{i=1}^{d} W^{(1)}_{ij} X^{(t-1)}_i + \varepsilon_j^{(t)}
ここで、W^{(0)}は同時点での因果関係、W^{(1)}は時間的な因果関係を表す重み行列です。

グラフニューラルネットワークを用いたDAG学習

最近の研究では、グラフニューラルネットワーク(GNN)を用いたDAG学習手法も提案されています。例えば、GraN-DAGはLachapelleらによって2020年に提案され、複雑な非線形関係を捉えることができます。

介入データを用いたDAG学習

純粋な観測データだけでなく、介入データ(介入実験から得られたデータ)を活用したDAG学習手法も発展しています。例えば、DCDI(Differentiable Causal Discovery from Interventional Data)は、介入データを効率的に利用するための微分可能なフレームワークを提供します。

おわりに

NOTEARSは、従来の組合せ最適化に基づくDAG学習から、連続最適化に基づくアプローチへのパラダイムシフトをもたらしました。その中核となるアイデアは、非巡回性という離散的な制約を、行列指数関数を用いた滑らかな制約に変換したことにあります。
この革新的なアプローチにより、より大規模なグラフの学習や、より複雑なモデルの導入が可能になりました。pytorchなどで組めばGPUによる並列化も可能なため大規模なグラフの学習がより可能となります。個人的に因果探索をするときは、参考程度に一旦俯瞰して因果グラフを眺めたい場合がほとんどなので、大規模なデータをある程度高速に学習可能で、工夫を凝らせば異なる変数形(連続値、離散値、二値、カテゴリ)でもそのまま入れて学習することが可能と考えると、因果探索を扱う手段としては第一候補に挙がるかなと感じました。実際に因果探索を実行してみた感じもいい感じだったので、この方向性で色々遊んでみたいなと思っています。
最後までお読みいただきありがとうございます!