- はじめに
- 基礎理論(ベイズ統計と決定木の関係)
- 代表的な手法(ベイズ回帰木、ベイズ適応的パーティショニング、BARTなど)
- 推論アルゴリズム(MCMC、スパイク&スラブ、ディリクレ過程との関係)
- 数学的導出(各手法の数学的導出を丁寧に)
- 最新の研究動向
- 実装(数式理解を補助するためのコード例)
- BART の理論的性質の概略証明
- 最後に
はじめに
データ解析において決定木は、特徴量空間を領域に区切り各領域で予測を行うシンプルで解釈しやすいモデルとして広く使われています。しかし、従来の決定木アルゴリズム(CARTなど)は貪欲法による学習のため局所解に陥りやすく、得られるモデルは不確実性の評価が困難であるという課題があります。また、決定木単体では過学習しやすい傾向も知られています。こうした問題に対し、ベイズ的手法を導入することでモデル構造に対する事前知識を組み込み、統計的な不確実性の定量化やモデル平均化による汎化性能の向上を図ることができます。特に近年では、ベイズ的なアンサンブル学習であるBART(Bayesian Additive Regression Trees、ベイズ加法回帰木)が高い予測精度とロバスト性を示し、因果推論など幅広い応用で注目を集めています。本記事では、ベイズ決定木の基礎理論から代表的な手法、推論アルゴリズム、数学的導出、最新の研究動向、さらに理解を深めるための実装例までを網羅的に解説します。読者は多少の数学的素養を仮定しますが、直感的な説明を交えますので、数式だけに頼らず文章から概念を掴んでいただける構成を目指しています。
基礎理論(ベイズ統計と決定木の関係)
まず、決定木モデルをベイズ的に扱うとはどういうことかを整理します。決定木モデルでは、データ空間を複数の領域(リージョン)に適応的に分割し、各領域でシンプルな予測(例えば一定値による回帰やクラス確率による分類)を行います。このモデル構造自体(すなわち木構造や各ノードの分割ルール、および終端ノード(リーフ)の出力パラメータ)に対して確率モデルを定義するのがベイズ的アプローチです。具体的には、決定木の構造を表すパラメータを 、終端ノードのパラメータを
とすると、事前分布
を与え、観測データ
に対する尤度
と組み合わせて事後分布
を考えます。
代表的な手法(ベイズ回帰木、ベイズ適応的パーティショニング、BARTなど)
ベイズ的な決定木アプローチにはいくつかの代表的手法があり、それぞれ特徴的なモデル化の工夫があります。本節では主な手法としてベイズ回帰木(単一のベイズ決定木モデル)、ベイズ適応的パーティショニング、そしてBART(ベイズ加法回帰木)を取り上げ、それぞれの概要と違いを説明します。
ベイズ回帰木(Bayesian CART)
ベイズ適応的パーティショニング(Bayesian Adaptive Partitioning)
BART(Bayesian Additive Regression Trees)
推論アルゴリズム(MCMC、スパイク&スラブ、ディリクレ過程との関係)
ベイズ決定木モデルの学習(事後分布の計算)は解析的に求めることが困難なため、一般には MCMC(Markov chain Monte Carlo)アルゴリズムによる近似推論が用いられます。単一の決定木モデルの場合、可逆ジャンプ MCMC(RJMCMC)を用いて、以下のような木操作の提案をランダムに行い、それをメトロポリス・ヘイスティングス法で受理または拒否します。
数学的導出(各手法の数学的導出を丁寧に)
本節では、ベイズ決定木のいくつかの核心的な数理について、できるだけ行間を埋める形で導出します。主に回帰木を題材とし、共役事前の下での事後計算や周辺尤度の導出などを示します。読者の理解を助けるため、導出の後に簡潔なコード例も交えて確認していきます。
単純な例での事後分布計算
はじめに、ごく簡単な決定木モデルで事後分布の計算を具体的に行ってみましょう。例えば、特徴量が1次元 のみで、データを2つの領域に分割するか否かを考えるスタンプ(切り株)モデルを仮定します。これは深さ1の決定木(根が終端か、根が2つの子を持つか)に相当します。根ノードで閾値
を用いて
を分割し、左リージョンと右リージョンでそれぞれ定数予測
を行うモデルか、あるいはどこも分割せず全データをひとつのリージョンで定数予測
するモデルのどちらかです。事前分布として、分割するか否かに確率
(深さ0なので確率
で分割)を与え、閾値は一様分布、また
には
の事前、観測ノイズは既知
とします。データ集合を
とします。
最新の研究動向
実装(数式理解を補助するためのコード例)
import pymc as pm with pm.Model() as model: # データ x, y が与えられているとする μ = pm.BART("μ", X=x, Y=y, m=50) # 50 本の木の BART 回帰 σ = pm.HalfNormal("σ", 1.0) y_obs = pm.Normal("y_obs", mu=μ, sigma=σ, observed=y) trace = pm.sample(1000, chains=4)
import numpy as np import math import random import copy # 簡単なデータ例 X = np.array([1, 2, 3, 8, 9, 10]) y = np.array([5.1, 4.9, 5.0, 8.0, 7.9, 8.1]) alpha, beta = 0.9, 2.0 mu0, tau2 = 0.0, 10.0 sigma2 = 0.5**2 class Node: def __init__(self, idx): self.idx = idx self.split_var = None self.split_val = None self.left = None self.right = None self.is_leaf = True self.mu = 0.0 root = Node(np.arange(len(X))) def depth(node, current=0): return current def split_candidates_for(node): idx = node.idx if len(idx) <= 1: return [] x_vals = sorted(X[idx]) candidates = [] for i in range(len(x_vals)-1): c = 0.5 * (x_vals[i] + x_vals[i+1]) if np.any(X[idx] <= c) and np.any(X[idx] > c): candidates.append(c) return candidates def log_marginal_likelihood(node): idx = node.idx n_node = len(idx) if n_node == 0: return 0.0 y_node = y[idx] ybar = np.mean(y_node) S_node = np.sum((y_node - ybar)**2) return -0.5 * ((ybar - mu0)**2 / (sigma2/n_node + tau2) + S_node/sigma2) def log_prior(node): d = depth(node) if node.is_leaf: return math.log(1 - alpha*(1+d)**(-beta)) else: return math.log(alpha*(1+d)**(-beta)) - math.log(len(split_candidates_for(node))) current_root = root samples = [] for t in range(1000): leaves = [current_root] if current_root.is_leaf else [current_root.left, current_root.right] if random.random() < 0.5: node = random.choice(leaves) candidates = split_candidates_for(node) if not candidates: continue c = random.choice(candidates) new_left_idx = node.idx[X[node.idx] <= c] new_right_idx = node.idx[X[node.idx] > c] if len(new_left_idx)==0 or len(new_right_idx)==0: continue new_node = Node(node.idx) new_node.is_leaf = False new_node.split_var = 0 new_node.split_val = c new_node.left = Node(new_left_idx) new_node.right = Node(new_right_idx) new_node.left.mu = np.mean(y[new_left_idx]) new_node.right.mu = np.mean(y[new_right_idx]) logp_current = log_prior(node) + log_marginal_likelihood(node) logp_new = (log_prior(new_node) + log_prior(new_node.left) + log_prior(new_node.right) + log_marginal_likelihood(new_node.left) + log_marginal_likelihood(new_node.right)) if math.log(random.random()) < (logp_new - logp_current): current_root = new_node else: if not current_root.is_leaf: left = current_root.left right = current_root.right logp_current = (log_prior(current_root) + log_prior(left) + log_prior(right) + log_marginal_likelihood(left) + log_marginal_likelihood(right)) pruned = Node(current_root.idx) logp_new = log_prior(pruned) + log_marginal_likelihood(pruned) if math.log(random.random()) < (logp_new - logp_current): current_root = pruned samples.append(copy.deepcopy(current_root)) split_points = [s.split_val for s in samples if not s.is_leaf] print("サンプルされた分割位置の平均:", np.mean(split_points))
# 新規入力に対する予測分布の近似 x_new = 7.5 preds = [] for tree in samples[500:]: node = tree while not node.is_leaf: if x_new <= node.split_val: node = node.left else: node = node.right preds.append(node.mu) print("予測平均:", np.mean(preds), "予測標準偏差:", np.std(preds))
BART の理論的性質の概略証明