ベイズ決定木系モデルの理論(特にBART)

はじめに

データ解析において決定木は、特徴量空間を領域に区切り各領域で予測を行うシンプルで解釈しやすいモデルとして広く使われています。しかし、従来の決定木アルゴリズムCARTなど)は貪欲法による学習のため局所解に陥りやすく、得られるモデルは不確実性の評価が困難であるという課題があります。また、決定木単体では過学習しやすい傾向も知られています。こうした問題に対し、ベイズ的手法を導入することでモデル構造に対する事前知識を組み込み、統計的な不確実性の定量化やモデル平均化による汎化性能の向上を図ることができます。特に近年では、ベイズ的なアンサンブル学習であるBART(Bayesian Additive Regression Trees、ベイズ加法回帰木)が高い予測精度とロバスト性を示し、因果推論など幅広い応用で注目を集めています。本記事では、ベイズ決定木の基礎理論から代表的な手法、推論アルゴリズム、数学的導出、最新の研究動向、さらに理解を深めるための実装例までを網羅的に解説します。読者は多少の数学的素養を仮定しますが、直感的な説明を交えますので、数式だけに頼らず文章から概念を掴んでいただける構成を目指しています。

基礎理論(ベイズ統計と決定木の関係)

まず、決定木モデルをベイズ的に扱うとはどういうことかを整理します。決定木モデルでは、データ空間を複数の領域(リージョン)に適応的に分割し、各領域でシンプルな予測(例えば一定値による回帰やクラス確率による分類)を行います。このモデル構造自体(すなわち木構造や各ノードの分割ルール、および終端ノード(リーフ)の出力パラメータ)に対して確率モデルを定義するのがベイズ的アプローチです。具体的には、決定木の構造を表すパラメータを T、終端ノードのパラメータを \Theta とすると、事前分布 P(T,\Theta) を与え、観測データ D に対する尤度 P(D\mid T,\Theta) と組み合わせて事後分布 P(T,\Theta \mid D) を考えます。

P(T,\Theta\mid D)\propto P(D\mid T,\Theta)P(T,\Theta)

ここで、P(T,\Theta) はモデルに対する我々の事前の信念を符号化し、複雑な木構造にペナルティを課すことで過度に複雑なモデルを抑制する正則化効果を持ちます。一方、P(D\mid T,\Theta) は与えられた木モデルがデータをどれだけよく説明するかを表すもので、従来の決定木の「損失関数」に対応します。ベイズ的枠組みでは、この事後分布そのものが分析の対象となり、最尤の木構造だけでなく不確実性を含めたモデル平均効果を考慮できる点が重要です。
例えば予測においては、事後分布による事後予測分布
P(y_{\text{new}}\mid x_{\text{new}},D)=\int P(y_{\text{new}}\mid x_{\text{new}},T,\Theta)P(T,\Theta\mid D)dT\,d\Theta

を用いることで、未知入力 x_{\text{new}} に対する予測 y_{\text{new}} の分布が得られます。これは事後分布を用いたベイズモデル平均化であり、決定木モデルの不安定さを緩和し予測精度を向上させる効果があります。さらに、ベイズ統計ではパラメータを確率的に扱いますが、決定木の場合の「パラメータ」には連続値の終端ノード出力だけでなく、木の形そのものが含まれるため、事前分布の工夫や計算手法が非常に重要になります。

すなわち、決定木の持つ柔軟性(任意の関数形状への適合能力)をベイズ推論で包み込むことで、データに対する過剰適合を防ぎつつ高い表現力を維持することが可能になります。事前分布 P(T,\Theta) の具体例としては、まず木構造 T と終端ノードパラメータ \Theta に分けて独立に与えるのが一般的です(事前独立性)。また、木構造 T に対する典型的な事前分布では、木が深くなるほど発生確率が低くなるという仕組みを組み込みます。例えば各ノードが分割(内部ノード)となる確率を深さ d の関数 \alpha (1+d)^{-\beta} とする方法が挙げられます。

代表的な手法(ベイズ回帰木、ベイズ適応的パーティショニング、BARTなど)

ベイズ的な決定木アプローチにはいくつかの代表的手法があり、それぞれ特徴的なモデル化の工夫があります。本節では主な手法としてベイズ回帰木(単一のベイズ決定木モデル)、ベイズ適応的パーティショニング、そしてBART(ベイズ加法回帰木)を取り上げ、それぞれの概要と違いを説明します。

ベイズ回帰木(Bayesian CART

ベイズ回帰木とは、1本の決定木に対してベイズ推論を行う手法で、いわば「ベイズCART」とも言えるものです。Chipman らによる1998年の古典的研究では、CART モデル全体の空間に事前分布を定義し、確率的サーチによって高い事後確率を持つ木構造を探索するアプローチが提案されました。この手法では、前節で述べたように「小さな木を事前に優先する」確率モデルを用い、データによる尤度と組み合わせて木構造の事後分布を評価します。得られた事後分布は単一の最良木だけでなく不確実性も表現しており、例えば新たなデータ点に対する予測では、事後分布に基づくモデル平均予測を行うことで、複数の木構造にわたる予測のばらつきを考慮できます。また、ベイズ回帰木では木構造過学習抑制が事前分布によって為される点も重要です。通常のCARTではプリーニングや木の深さ制限などで過学習を防ぎますが、ベイズ回帰木では、例えば \alpha\beta のハイパーパラメータによる分割確率減衰が、深い木に対する事前確率を指数的に小さくし、同時に各終端ノードのパラメータに対して事前分布による縮小(シュリンク)を掛けることで、必要以上に細かく分割することを抑え、木全体の予測が極端な値を取らないようにします。

ベイズ適応的パーティショニング(Bayesian Adaptive Partitioning)

ベイズ適応的パーティショニングは、決定木に限らずデータ空間の領域分割によって予測モデルを構築するという発想をベイズ的に推し進めた手法です。Holmes や Denison らによる1990年代末の研究では、入力空間を適応的に長方形領域に分割し、各領域で定数や線形モデルを当てはめるパーティションモデルにベイズ法を適用する枠組みが検討されました。これは決定木による分割と類似していますが、必ずしも木構造入れ子の二分割)に限定せず、柔軟な分割を探索できる点が特徴です。そのため、「適応的パーティショニング」と呼ばれ、データ空間全体に対して一様な事前を置くのではなく、領域の数に対してポアソン分布やその他の事前分布を設定することで、モデルの複雑さを自動制御するアプローチも考案されています。

BART(Bayesian Additive Regression Trees)

BART は、未知の回帰関数 f(x)f(x)\approx\sum_{j=1}^{m}g_j(x) と近似し、複数の弱い決定木のアンサンブルによって予測を行う手法です。BART は単一の木ではなく、複数の木の和を用いることで、非線形な関数関係を柔軟に近似できるとともに、各木の寄与を事前分布で正則化することで過学習を防ぎます。具体的には、観測 y

y=\sum_{j=1}^{m}g_j(x)+\epsilon

ϵ∼N(0,σ^2)BART の学習では、各木の構造と終端ノードのパラメータが MCMC により更新され、未知入力 x_{\text{new}} に対する予測は、各木の予測を合算した事後平均として求められます。これにより、BART は単一の木に比べ高い予測精度と不確実性の定量化が可能となります。

推論アルゴリズムMCMC、スパイク&スラブ、ディリクレ過程との関係)

ベイズ決定木モデルの学習(事後分布の計算)は解析的に求めることが困難なため、一般には MCMC(Markov chain Monte Carlo)アルゴリズムによる近似推論が用いられます。単一の決定木モデルの場合、可逆ジャンプ MCMC(RJMCMC)を用いて、以下のような木操作の提案をランダムに行い、それをメトロポリスヘイスティングス法で受理または拒否します。

たとえば、「成長 (grow)」ではランダムに選んだ終端ノードを内部ノードに変え、そこに二分割のルールと 2 つの新たな終端ノードを追加します。「剪定 (prune)」ではランダムに選んだ内部ノードを終端ノードに戻し、その子孫ノードを削除します。「変更 (change)」では既存の内部ノードの分割変数または閾値の値を変更し、「スワップ (swap)」では分割の階層を入れ替える操作を行います。

これらの操作により、現在の木構造 T から新たな構造 T' への遷移が行われ、提案前後のモデルの事後確率比
\alpha=\min\Bigg(1,\frac{P(D\mid T',\Theta')P(T',\Theta')q(T'\to T)}{P(D\mid T,\Theta)P(T,\Theta)q(T\to T')}\Bigg)
が受理されます。(ここで q(\cdot\to\cdot) は提案確率であり、必要に応じてヤコビアン補正も含まれます。)
BART の場合は、複数の決定木からなるモデル全体を更新するため、各木を交互に更新するバックフィッティング方式、すなわちギブスサンプリングが採用されます。各木 g_j(x) について、全体の予測から他の木の寄与を除いた残差 R_j=y-\sum_{k\neq j}g_k(x) を用い、その残差にフィットする単一木モデルとして更新を行います。共役事前を用いることで、各終端ノードのパラメータ \mu は解析的に更新可能となり、ギブスサンプリングが実現されます。
さらに、高次元データにおいて不要な変数の影響を排除するため、各特徴量の分割利用に対してスパイク&スラブ事前を導入する手法が提案されています。ここでは、各特徴量 j が「有用か否か」を示す指標変数 \gamma_j を導入し、\gamma_j=1 ならその変数は分割に使われやすく、\gamma_j=0 なら使われないとし、各内部ノードの分割変数選択確率に重み \theta_j を割り当て、\theta=(\theta_1,\dots,\theta_p) に対してスパイク&スラブ型の事前分布を与えます。
この結果、事後分布において不要な変数の \theta_j は極端に小さくなり、重要な変数のみが分割に用いられるようになります。また、ディリクレ過程の考えを応用して、分割変数選択の事前分布を柔軟に設定する試みも行われています。ディリクレ過程は無限次元の事前分布として知られ、決定木の複雑なモデル空間を扱う際に、そのエッセンスを取り入れることで、事前としての柔軟性が向上します。

数学的導出(各手法の数学的導出を丁寧に)

本節では、ベイズ決定木のいくつかの核心的な数理について、できるだけ行間を埋める形で導出します。主に回帰木を題材とし、共役事前の下での事後計算や周辺尤度の導出などを示します。読者の理解を助けるため、導出の後に簡潔なコード例も交えて確認していきます。

単純な例での事後分布計算

はじめに、ごく簡単な決定木モデルで事後分布の計算を具体的に行ってみましょう。例えば、特徴量が1次元 x のみで、データを2つの領域に分割するか否かを考えるスタンプ(切り株)モデルを仮定します。これは深さ1の決定木(根が終端か、根が2つの子を持つか)に相当します。根ノードで閾値 c を用いて x を分割し、左リージョンと右リージョンでそれぞれ定数予測 \mu_L, \mu_R を行うモデルか、あるいはどこも分割せず全データをひとつのリージョンで定数予測 \mu_{\text{all}} するモデルのどちらかです。事前分布として、分割するか否かに確率 \alpha(深さ0なので確率 \alpha (1+0)^{-\beta}=\alpha で分割)を与え、閾値は一様分布、また \mu には \mu\sim N(\mu_0,\tau^2) の事前、観測ノイズは既知 \sigma^2 とします。データ集合を D={(x_i,y_i)}_{i=1}^n とします。

このモデルにおける事後確率は、2つの場合(分割なし vs 分割あり)それぞれについて以下で計算できます。まず、分割なしモデル(M_0:全体を1リージョン)では、尤度は各 y_i\mathcal{N}(\mu\_{\text{all}},\sigma^2) に従うとしたときの積になります。ここで \mu_{\text{all}} の事前は \mathcal{N}(\mu_0,\tau^2) です。よって \mu\_{\text{all}} に関する事後は共役な正規分布となり、周辺尤度(\mu\_{\text{all}}積分消去した尤度)は次のように計算できます。データの平均を \bar{y}=\frac{1}{n}\sum\_{i=1}^n y_i、データ分散和を S_y^2=\sum_{i=1}^n (y_i-\bar{y})^2 とすると、

P(D\mid M_0)=\int_{-\infty}^{\infty}\Bigg(\prod_{i=1}^n \frac{1}{\sqrt{2\pi\sigma^2}} \exp\Big(-\frac{(y_i-\mu)^2}{2\sigma^2}\Big)\Bigg)\frac{1}{\sqrt{2\pi\tau^2}} \exp\Big(-\frac{(\mu-\mu_0)^2}{2\tau^2}\Big) d\mu
この積分は標準的な正規-正規モデルの周辺尤度計算であり、結果は閉じた形になります。具体的には、事後平均 \mu_{\text{all}}^{\ast} と事後分散 V_{\text{all}}^{\ast} = \Big(\frac{1}{\tau^2}+\frac{n}{\sigma^2}\Big)^{-1} を求めた上で、

P(D\mid M_0)=\frac{1}{\sqrt{2\pi}^{\,n}}\frac{1}{\sigma^n}\sqrt{\frac{V_{\text{all}}^{*}}{\tau^2}} \exp\Bigg(-\frac{(\bar{y}-\mu_0)^2}{2\Big(\frac{\sigma^2}{n}+\tau^2\Big)} - \frac{S_y^2}{2\sigma^2}\Bigg)

と表されます。

次に、分割ありモデル(M_1閾値 c で左右2リージョン)では、左リージョン L={i: x_i \le c} と右リージョン R={i: x_i>c} にデータが分かれます(それぞれのデータ数を n_L, n_R とする)。各リージョンにそれぞれパラメータ \mu_L, \mu_R があり、y_i \sim N(\mu_L,\sigma^2)i\in L)、y_j \sim N(\mu_R,\sigma^2)j\in R)となります。事前は \mu_L, \mu_R \sim N(\mu_0,\tau^2) 独立です。この場合の周辺尤度は、各リージョンごとに先ほどと同様の積分を行い、それらを乗じることで得られます。すなわち、
P(D\mid M_1,c)=P(D_L\mid M_1)P(D_R\mid M_1)
となります。ここで、左領域の周辺尤度は、左領域の平均 \bar{y}_L、分散和 S_L^2、サンプル数 n_L、事後分散 V_L^{\ast} を用いて、

P(D_L\mid M_1)=\frac{1}{\sqrt{2\pi}^{\,n_L}}\frac{1}{\sigma^{n_L}}\sqrt{\frac{V_L^{*}}{\tau^2}} \exp\Bigg(-\frac{(\bar{y}\_L-\mu_0)^2}{2\Big(\frac{\sigma^2}{n_L}+\tau^2\Big)} - \frac{S_L^2}{2\sigma^2}\Bigg)

、右領域についても同様に求められ、全体の周辺尤度は上記の積となります。さらに、閾値 c に対して一様な事前を置くと、分割ありモデル全体の事後確率は \alpha\frac{1}{N_c}P(D\mid M_1,c) と表現され、分割なしモデルは (1-\alpha)P(D\mid M_0) となります。これらを正規化することで、どのモデルがデータをよりよく説明しているかが決定されます。

また、新たな入力 x_{\text{new}} に対する予測は、学習した事後分布 P(T,\Theta\mid D) を用いて、

P(y\_{\text{new}}\mid x\_{\text{new}},D)=\int P(y\_{\text{new}}\mid x\_{\text{new}},T,\Theta)P(T,\Theta\mid D)dT\,d\Theta

と表され、事後予測分布の期待値は

E[y_{\text{new}}\mid x_{\text{new}},D] E_{(T,\Theta)\sim P(\cdot\mid D)}\Big[E[y_{\text{new}}\mid x_{\text{new}},T,\Theta\Big]]
であり、予測分散は

\mathrm{Var}[y_{\text{new}}\mid x_{\text{new}},D]=E\_{(T,\Theta)\mid D}\Big[\mathrm{Var}(y\_{\text{new}}\mid x\_{\text{new}},T,\Theta)\Big]+\mathrm{Var}_{(T,\Theta)\mid D}\Big(E[y\_{\text{new}}\mid x_{\text{new}},T,\Theta]\Big)

と、モデル内部の誤差とモデル間の不確実性が反映されます。
BART では、未知の回帰関数 f(x)m 本の決定木 g_1(x),\dots, g_m(x) の和で近似し、

f(x)\approx\sum\_{j=1}^{m}g_j(x)

と表現します。観測 y は、
y=\sum_{j=1}^{m}g_j(x)+\epsilon,\qquad \epsilon\sim N(0,\sigma^2)
と仮定され、各木は事前分布によって正則化され、MCMC によるギブスサンプリングを通じてその構造と終端ノードのパラメータが更新され、全体の予測は事後平均化によって行われます。

最新の研究動向

ベイズ決定木に関する研究は近年も活発で、手法の改良や新たな応用が次々と報告されています。まず、従来の不連続な決定木を平滑化するため、各ノードの分割関数をヒンジ型からシグモイド型に変更する「ソフト決定木」が提案され、これにより関数近似が連続的に行われ、勾配法との連携も容易になっています。

さらに、高次元データへの適応として、スパイク&スラブ事前やディリクレ過程を利用して不要な変数を自動選別する拡張モデルが開発されています。これにより、実際に重要な変数のみが分割に用いられるようになり、モデルの複雑さが効果的に制御されます。
また、モデルの解釈性向上のため、各変数の分割頻度の事後分布や、ある変数を固定したときの部分依存プロットにベイズ信用区間を付与する方法など、不確実性付きのモデル解釈を与える試みが進められています。
さらに、大規模データへのスケーラビリティを向上させるため、逐次モンテカルロ法や変分推論、GPU 並列化を活用した近似アルゴリズムが検討され、実用上扱えるデータ規模が拡大しています。
最後に、新たな応用分野として、因果推論におけるバイアス調整や時間系列データへの適用が進められており、BART の柔軟性が幅広い問題解決に寄与する可能性が示唆されています。

実装(数式理解を補助するためのコード例)

ベイズ決定木や BART は、R 言語の dbarts や bartMachine、また Python の PyMC や bartpy などのライブラリで実装可能です。例えば、PyMC を用いた BART モデルの記述例は以下の通りです。
import pymc as pm
with pm.Model() as model:
    # データ x, y が与えられているとする
    μ = pm.BART("μ", X=x, Y=y, m=50)  # 50 本の木の BART 回帰
    σ = pm.HalfNormal("σ", 1.0)
    y_obs = pm.Normal("y_obs", mu=μ, sigma=σ, observed=y)
    trace = pm.sample(1000, chains=4)
また、理論理解のために自作の MCMC をスクラッチ実装する例も示します。以下は、1 次元データに対して成長や剪定の提案をランダムに行い、1000 ステップのサンプルを取る簡単な実装例です(簡略化のため、分割位置はデータ点の中間のみ考慮しています)。
import numpy as np
import math
import random
import copy

# 簡単なデータ例
X = np.array([1, 2, 3, 8, 9, 10])
y = np.array([5.1, 4.9, 5.0, 8.0, 7.9, 8.1])
alpha, beta = 0.9, 2.0
mu0, tau2 = 0.0, 10.0
sigma2 = 0.5**2

class Node:
    def __init__(self, idx):
        self.idx = idx
        self.split_var = None
        self.split_val = None
        self.left = None
        self.right = None
        self.is_leaf = True
        self.mu = 0.0

root = Node(np.arange(len(X)))

def depth(node, current=0):
    return current

def split_candidates_for(node):
    idx = node.idx
    if len(idx) <= 1:
        return []
    x_vals = sorted(X[idx])
    candidates = []
    for i in range(len(x_vals)-1):
        c = 0.5 * (x_vals[i] + x_vals[i+1])
        if np.any(X[idx] <= c) and np.any(X[idx] > c):
            candidates.append(c)
    return candidates

def log_marginal_likelihood(node):
    idx = node.idx
    n_node = len(idx)
    if n_node == 0:
        return 0.0
    y_node = y[idx]
    ybar = np.mean(y_node)
    S_node = np.sum((y_node - ybar)**2)
    return -0.5 * ((ybar - mu0)**2 / (sigma2/n_node + tau2) + S_node/sigma2)

def log_prior(node):
    d = depth(node)
    if node.is_leaf:
        return math.log(1 - alpha*(1+d)**(-beta))
    else:
        return math.log(alpha*(1+d)**(-beta)) - math.log(len(split_candidates_for(node)))

current_root = root
samples = []
for t in range(1000):
    leaves = [current_root] if current_root.is_leaf else [current_root.left, current_root.right]
    if random.random() < 0.5:
        node = random.choice(leaves)
        candidates = split_candidates_for(node)
        if not candidates:
            continue
        c = random.choice(candidates)
        new_left_idx = node.idx[X[node.idx] <= c]
        new_right_idx = node.idx[X[node.idx] > c]
        if len(new_left_idx)==0 or len(new_right_idx)==0:
            continue
        new_node = Node(node.idx)
        new_node.is_leaf = False
        new_node.split_var = 0
        new_node.split_val = c
        new_node.left = Node(new_left_idx)
        new_node.right = Node(new_right_idx)
        new_node.left.mu = np.mean(y[new_left_idx])
        new_node.right.mu = np.mean(y[new_right_idx])
        logp_current = log_prior(node) + log_marginal_likelihood(node)
        logp_new = (log_prior(new_node) + log_prior(new_node.left) + log_prior(new_node.right) +
                    log_marginal_likelihood(new_node.left) + log_marginal_likelihood(new_node.right))
        if math.log(random.random()) < (logp_new - logp_current):
            current_root = new_node
    else:
        if not current_root.is_leaf:
            left = current_root.left
            right = current_root.right
            logp_current = (log_prior(current_root) + log_prior(left) + log_prior(right) +
                            log_marginal_likelihood(left) + log_marginal_likelihood(right))
            pruned = Node(current_root.idx)
            logp_new = log_prior(pruned) + log_marginal_likelihood(pruned)
            if math.log(random.random()) < (logp_new - logp_current):
                current_root = pruned
    samples.append(copy.deepcopy(current_root))
split_points = [s.split_val for s in samples if not s.is_leaf]
print("サンプルされた分割位置の平均:", np.mean(split_points))
新規入力 x_{\text{new}} に対する予測分布の近似は、各 MCMC サンプルの木構造においてルートから順に分割条件を辿ることで求めます。以下のコード例では、入力 7.5 に対して各サンプルの終端ノードの予測値を集め、その平均と標準偏差から予測分布を評価しています。
# 新規入力に対する予測分布の近似
x_new = 7.5
preds = []
for tree in samples[500:]:
    node = tree
    while not node.is_leaf:
        if x_new <= node.split_val:
            node = node.left
        else:
            node = node.right
    preds.append(node.mu)
print("予測平均:", np.mean(preds), "予測標準偏差:", np.std(preds))

BART の理論的性質の概略証明

BART の理論的性質に関する研究では、BART の変種(例えばスパース正則化やソフトな分割関数を導入したモデル)について、事後分布が真の回帰関数に対してミニマックス最適な収束レートで集中することが示されています。直感的には、BART は高次元環境においても不要な変数の影響を自動的に除去し、関数の滑らかさに適応して学習する性質を持つため、真の関数に効率よく収束することが保証されます。

証明の鍵は、決定木による関数近似が局所的な区分定数関数の集合として任意の連続関数を一様近似可能であるという事実に基づいており、BART の事前分布が各木の寄与を十分に小さく保つことで、全体として滑らかな関数近似を実現する点にあります。さらに、スパイク&スラブ事前やソフトスプリット(シグモイド関数による分割表現)を導入することで、事前がほとんどゼロの勾配を持つ部分関数や、特定の変数に依存しない関数に質量を集中させることができ、真の関数が滑らかかつスパースであればその周辺に事後が集中しやすくなることが示されています。
これらの結果は、十分な木の本数 m と適切な事前ハイパーパラメータの設定の下で、BART が統計的に有効な推定器であることを保証するものです。証明の詳細は高度なベイズパラメトリック理論に踏み込む必要がありますが、BART のモデル空間が任意の連続関数を一様に近似可能であり、事後分布が真の関数近傍に集中する(後方集中性)ことを示すものであり、これにより BART の予測が高い信頼性を持つことが理論的に裏付けられています。

最後に

今回は、ベイズ的な決定木手法について、基礎理論から代表的な手法(ベイズ回帰木、ベイズ適応的パーティショニング、BART)、推論アルゴリズム、数学的導出、BART の理論的性質の概略証明、最新の研究動向、さらに理解を深めるための実装例までを網羅的に解説しました。決定木モデルに対して事前分布を与え、MCMC などの推論アルゴリズムを用いて事後分布を求めることで、従来の決定木の不安定性や過学習を抑制し、事後平均化による予測および不確実性の評価が可能となります。BART のようなアンサンブル手法は、その柔軟性と高い予測精度により、現代の機械学習モデルとして非常に有用です。今後も理論のさらなる発展と計算手法の高速化により、ベイズ決定木はより広範な分野で活躍することが期待されます。