- はじめに
- 因果探索の伝統的アプローチとその限界
- 連続最適化としての因果探索:発想の転換
- NOTEARSの理論:非巡回性の微分可能な特性付け
- NOTEARSアルゴリズム:実装と最適化
- NOTEARSの拡張:非線形モデルへの適用
- NOTEARSの理論的保証と実践的性能
- 最新の発展:NOTEARS以降の微分可能因果探索
- おわりに
はじめに
因果関係の探索、特に有向非巡回グラフ(DAG: Directed Acyclic Graph)の推定は、機械学習や統計学における重要な課題です。DAGは変数間の因果関係を矢印で表現するグラフであり、Bayesianネットワークとも呼ばれています。従来、DAGの学習は組合せ最適化問題として定式化され、非巡回性(acyclicity)の制約を満たすグラフ構造を探索するという非常に難しい問題でした。
なぜ難しいのでしょうか?それは、変数の数が増えるにつれてDAGの可能な構造の数が超指数関数的に増加するためです。例えば、10個の変数だけでも可能なDAG構造は約1018個あります。また、グラフの非巡回性という制約は組合せ的な性質を持ち、連続最適化の枠組みで扱うことが困難でした。
しかし、2018年にZhengらによって提案されたNOTEARS(Non-combinatorial Optimization via Trace Exponential and Augmented lagRangian for Structure learning)は、この難問に革命的なアプローチをもたらしました。NOTEARSの最大の革新点は、離散的な非巡回性制約を連続的な制約に変換したことにあります。これにより、DAG学習問題を標準的な連続最適化問題として解けるようになりました。
本記事では、微分可能因果探索の基本的な考え方から始め、NOTEARSアルゴリズムの理論と実装について詳しく解説します。さらに、NOTEARSから派生した最新の研究動向についても紹介します。
因果探索の伝統的アプローチとその限界
従来の因果グラフ探索は、以下のような離散的な最適化問題として定式化されていました:
これらの課題に対処するため、様々なアルゴリズムが開発されてきました。順序ベースの手法では変数の順序付けを探索し、その順序に基づいてDAGを構築します。貪欲探索法では局所的な改善を繰り返し行います(例:Greedy Equivalence Search)。またスコアベースの手法ではスコア関数を最適化するグラフ構造を探索します。
これらの手法は一定の成功を収めていますが、高次元データや複雑なグラフ構造では依然として計算効率や精度の面で課題がありました。
連続最適化としての因果探索:発想の転換
NOTEARSの革新的なアイデアは、離散的なDAG探索問題を連続最適化問題に変換することです。この転換の鍵となるのは、グラフの隣接行列とその性質です。
構造方程式モデル(SEM)と隣接行列
NOTEARSは線形構造方程式モデル(SEM)を前提としています:
問題の再定式化
NOTEARSは、DAG学習問題を以下のように再定式化します:
NOTEARSの理論:非巡回性の微分可能な特性付け
Zhengらは、以下の関数がこれらの条件を満たすことを示しました:
この関数が非巡回性をどのように特徴付けるのか、直観的に理解してみましょう。
非巡回性の連続的表現:直観的理解
行列指数関数の定義から:
このトレースをとると:
この特性付けにより、離散的なDAG制約を連続的な等式制約[tex(W) = 0]に置き換えることができます。この制約は微分可能であり、その勾配も簡単に計算できます:
これにより、問題は以下の等式制約付き連続最適化問題に変換されます:
この問題は、拡張ラグランジュ法や罰則法などの標準的な最適化手法で解くことができます。
NOTEARSアルゴリズム:実装と最適化
NOTEARSアルゴリズムは、上記の等式制約付き最適化問題を拡張ラグランジュ法を用いて解きます。拡張ラグランジュ関数は以下のように定義されます:
閾値処理ステップは、推定された行列の小さな値(ノイズに起因する可能性がある)をゼロにすることで、より明確なグラフ構造を得るために重要です。
アルゴリズムの実装詳細
NOTEARSの実装は非常にシンプルで、標準的な数値最適化ライブラリを使用して約50行のPythonコードで実現できます。以下はアルゴリズムの核心部分を表すコードの概要です。
def notears(X, lambda1, loss_type='l2'): """ NOTEARSアルゴリズムの実装 Parameters ---------- X : numpy.ndarray n×dのサンプルデータ行列 lambda1 : float L1正則化パラメータ loss_type : str 損失関数のタイプ ('l2' または 'logistic') Returns ------- W_est : numpy.ndarray 推定された重み行列(隣接行列) """ n, d = X.shape # 最適化問題の初期値 W_est = np.zeros((d, d)) # 目的関数の定義 def _loss(W): if loss_type == 'l2': return 0.5 / n * np.square(np.linalg.norm(X - X @ W, 'fro')) elif loss_type == 'logistic': # ロジスティック回帰損失の実装(省略) pass # 非巡回性制約関数の定義 def _h(W): M = W * W # アダマール積 E = scipy.linalg.expm(M) # 行列指数関数 h = np.trace(E) - d return h # 非巡回性制約の勾配 def _grad_h(W): M = W * W E = scipy.linalg.expm(M) return E.T * (2 * W) # 拡張ラグランジュ法のパラメータ alpha = 0 rho = 1.0 h_tol = 1e-8 rho_max = 1e+16 max_iter = 20 # 拡張ラグランジュ法の実行 for _ in range(max_iter): # 部分問題の解法 def _objective(w): W = w.reshape((d, d)) loss = _loss(W) + lambda1 * np.sum(np.abs(W)) h_val = _h(W) obj = loss + alpha * h_val + 0.5 * rho * h_val * h_val return obj # 部分問題を解く(例:L-BFGS) w_new = scipy.optimize.minimize( _objective, W_est.flatten(), method='L-BFGS-B').x W_new = w_new.reshape((d, d)) # DAG制約の評価 h_new = _h(W_new) # 収束判定 if h_new <= h_tol: W_est = W_new break # ラグランジュ乗数の更新 alpha = alpha + rho * h_new rho = min(rho_max, 10 * rho) W_est = W_new # 小さな値を閾値処理 W_est[np.abs(W_est) < 0.3] = 0 return W_est
NOTEARSの拡張:非線形モデルへの適用
元のNOTEARSは線形SEMを前提としていますが、この考え方は非線形モデルにも拡張できます。2019年にZhengらによって提案されたNonlinear NOTEARSは、多層パーセプトロン(MLP)などの非線形関数を用いて変数間の関係をモデル化します。
NOTEARSの理論的保証と実践的性能
実践的には、NOTEARSは従来のGreedy Equivalence Search(GES)などの手法と比較して、特に変数の数が多い場合や、グラフの次数(各ノードの接続数)が大きい場合に優れた性能を示します。特に、スケールフリーグラフなどの複雑な構造を持つグラフの学習において、NOTEARSの優位性が顕著です。
最新の発展:NOTEARS以降の微分可能因果探索
NOTEARSの成功以降、微分可能因果探索の分野は急速に発展しています。以下に最新の研究動向をいくつか紹介します。
DAGMA: DAG Structure Learning via Matrix Exponential
DAGMAは、NOTEARSをさらに効率化したアルゴリズムで、2021年にYuらによって提案されました。DAGMAは行列指数関数ベースの勾配降下法を用いて、より高速かつ正確にDAGを学習します。特に大規模グラフの学習において計算効率が大幅に向上しています。
GOLEM: Greedy Optimization of the Evidence Lower Bound
GOLEMは、ベイズ的アプローチを採用したDAG学習法で、2020年にNgらによって提案されました。変分推論を用いて証拠下限(ELBO)を最適化することで、不確実性を考慮したDAG学習を可能にします。
DYNOTEARS: ダイナミックネットワーク学習への拡張
DYNOTEARSは、NOTEARSを時系列データに拡張したもので、2020年にPamfilらによって提案されました。時間的依存関係を考慮したDAG学習を可能にします。
グラフニューラルネットワークを用いたDAG学習
最近の研究では、グラフニューラルネットワーク(GNN)を用いたDAG学習手法も提案されています。例えば、GraN-DAGはLachapelleらによって2020年に提案され、複雑な非線形関係を捉えることができます。
介入データを用いたDAG学習
純粋な観測データだけでなく、介入データ(介入実験から得られたデータ)を活用したDAG学習手法も発展しています。例えば、DCDI(Differentiable Causal Discovery from Interventional Data)は、介入データを効率的に利用するための微分可能なフレームワークを提供します。
おわりに
NOTEARSは、従来の組合せ最適化に基づくDAG学習から、連続最適化に基づくアプローチへのパラダイムシフトをもたらしました。その中核となるアイデアは、非巡回性という離散的な制約を、行列指数関数を用いた滑らかな制約に変換したことにあります。
この革新的なアプローチにより、より大規模なグラフの学習や、より複雑なモデルの導入が可能になりました。pytorchなどで組めばGPUによる並列化も可能なため大規模なグラフの学習がより可能となります。個人的に因果探索をするときは、参考程度に一旦俯瞰して因果グラフを眺めたい場合がほとんどなので、大規模なデータをある程度高速に学習可能で、工夫を凝らせば異なる変数形(連続値、離散値、二値、カテゴリ)でもそのまま入れて学習することが可能と考えると、因果探索を扱う手段としては第一候補に挙がるかなと感じました。実際に因果探索を実行してみた感じもいい感じだったので、この方向性で色々遊んでみたいなと思っています。
最後までお読みいただきありがとうございます!