決定木の理論とフルスクラッチ実装とその解説

最初に

決定木の理論とフルスクラッチ実装とその解説というと、既に使い古された話題の様に感じてしまいますが、今回の記事から派生して、ランダムフォレスト、GBDT、XGboost(LightGBMは扱わないつもり)、因果木、因果フォレスト、ランダムフォレスト-learnerの理論とできる部分はフルスクラッチ実装、めんどくさいものは、理論と解説に抑えて扱っていこうと考えており、そのまず初めとして、決定木自体の理論に触れないことは、できないなと思い、決定木の記事を書こうと思った次第です。(めんどくさくなって書かないパターンも全然あり得るのでご了承ください)他の記事との差別化は、数式を含めた解説と、フルスクラッチ実装のコードと数式を絡めた解説みたいな感じで、初心者に超優しい解説記事みたいな感じで仕上げて見せると、書き始めは思っております。書いていくうちに、初心者に超優しくないじゃんみたいなことになってしまう気もしなくはないですが、データサイエンスを学ぶ上で、理論をしっかり理解するというのはとても重要なことなので、理論に重きを置き、フルスクラッチ実装は理論を解釈する上での補助ツール的な感じで、進めていこうと思います。
はてなブログで数式(普通のTeXではなくはてなTexという曲者)を扱うと、レイアウトが思うようにいかないことが多いので、読みづらいと思ったら、はてなブログのせいにしてください。あと図を本当は入れた方がいいのでしょうが、作成するのがめんどくさいので、恐らく作らないと思います。

決定木とは

決定木

1.決定木の基本構造

決定木の基本構造は、ルートノード、内部ノード、葉ノードから構成されます。各ノードは特徴量に基づいて分割を行い、最終的に葉ノードで予測を行います。

2.不純度

次に、不純度というものの解説に入ります。
主な不純度の指標には以下があります。不純度指標は、決定木の各ノードでの分割の品質を評価し、最適な分割を選択するために使用されます。不純度が低いほど、そのノードでのデータの分類や予測がより確実であることを示します。

ジニ不純度(分類問題)

Gini(T) = 1 - \sum_{i=1}^{c} (p_i)^2
T:ノード(データ集合)
ここで、p_iはクラスiの割合です。 ジニ不純度は、ランダムに選ばれた要素が誤って分類される確率を表します。値が0に近いほど、ノードの純度が高いことを意味します。完全に純粋なノード(単一クラスのみを含む)の場合、ジニ不純度は0となります。最大値は (1 - \frac{1}{c}) で、すべてのクラスが均等に分布している場合に達します。

  
エントロピー(分類問題)

H(T) = -\sum_{i=1}^{c} p_i \log_2(p_i)
エントロピーは、ノード内のデータの無秩序さや不確実性を測る指標です。値が低いほど、ノードの純度が高いことを意味します。完全に純粋なノードの場合、エントロピーは0になります。最大値は \log_2(c)で、すべてのクラスが均等に分布している場合に達します。エントロピーは、情報理論に基づいており、データを符号化するのに必要な最小ビット数と関連しています。

  
分散(回帰問題)

Var(T) = \frac{1}{N} \sum_{i=1}^{N} (y_i - \mu)^2
ここで、Nはサンプル数、y_iは各サンプルの目的変数、\muは平均値です。 良く知られている通り分散は、データポイントが平均値からどれだけばらついているかを示す指標です。回帰問題では、ノード内のデータポイントの目的変数がどれだけ集中しているかを測ります。分散が小さいほど、ノード内のデータポイントの値が近く、より純粋であることを意味します。完全に純粋なノード(すべての y_i が同じ値)の場合、分散は0になります。分散が大きいほど、ノード内のデータの不確実性が高いことを示します。

3. 情報利得

情報利得は、特定の特徴量による分割が、どれだけデータの不純度を減少させたかを測る指標です。言い換えれば、その分割によってどれだけ情報が得られたかを表します。情報利得が高いほど、その分割がデータをより良く分類できることを意味します。
ノードの分割は情報利得を最大化するように行われます。情報利得は以下のように計算されます。

IG(T, a) = I(T) - \sum_{v} \frac{|T_v|}{|T|} I(T_v)
ここで、I(T) は親ノードの不純度、T_v は特徴量 a で分割された子ノード、|T| はサンプル数です。 この式は、親ノードの不純度から、分割後の子ノードの重み付き平均不純度を引いたものです。つまり、分割前後での不純度の減少量を表しています。\frac{|T_v|}{|T|} は各子ノードの重みを表し、サンプル全体に対する各子ノードのサンプル数の割合です。 ということは、情報利得が大きいほど、その分割がデータセットの構造をより良く捉えていることを示します。決定木アルゴリズムは、各ステップで最大の情報利得を持つ特徴量と分割点を選択することで、効率的にデータを分類または回帰する木構造を構築します。 そして、このプロセスは再帰的に適用され、各ノードで最適な分割が選択されていきます。これにより、決定木は複雑なデータ構造を階層的に表現し、解釈可能なモデルを生成することができるというわけです。

注意して欲しいこと

ここで決定木は複雑なデータ構造を階層的に表現し、解釈可能なモデルを生成するという言葉をわざわざ入れた理由は、線形回帰と同様に、解釈性に優れるため、アナリティクス(分析)の手段として使用されることが、現在もあるためです。ですがここで注意して欲しいのは、線形回帰同様に、分類や予測タスクで使用しない場合(データの解釈のために使用する場合)、要は決定木モデルの構造に注目する場合は、多重共線性に注意する必要が出てくるという点です。なぜならば、変数間で相関係数が非常に高い変数が複数ある場合(同じような情報を持つ変数が複数ある場合)は、分岐条件に表れる変数というのが、相関が高い変数間で奪い合われ、正しい構造を表さない危険性があるためです。線形回帰で例えると、多重共線性が発生した場合、各係数の推定値が不安定になり、正しい推定が困難になりますが、このようなことが決定木でも発生するということです。予測結果や分類結果に注目するのであれば、多重共線性を意識する必要はありませんが、解釈するツールとして使用する場合は、多重共線性の問題を意識せざる負えないのは、決定木系の手法全てに共通することです。それはランダムフォレストでもGBDT系でも、それをshapで見るんでも、多重共線性を意識しなくてはなりません。

4. 最適分割点の選択

各特徴量について、可能なすべての分割点で情報利得を計算し、最大の情報利得を持つ分割を選択します。

5. 再帰的な木の成長

選択された最適分割に基づいてノードを分割し、子ノードに対して同じプロセスを再帰的に適用します。

6. 停止条件

以下のいずれかの条件を満たすまで木を成長させます:

  • 最大深さに達した
  • ノードのサンプル数が最小値を下回った
  • 情報利得が閾値を下回った

7. 枝刈り(オプション)

枝刈りは、決定木の過学習を防ぐための重要な技術です。完全に成長した決定木は、トレーニングデータに対して過度に適合し、新しいデータに対する汎化性能が低下する可能性があります。枝刈りは、この問題に対処するために使用されます。
コスト複雑性枝刈り(Cost-Complexity Pruning)は、最も一般的な枝刈り手法の1つです。この方法は、木の複雑さと予測誤差のバランスを取ることを目的としています。

コスト複雑性パラメータ \alpha は以下のように定義されます:
\alpha(T) = \frac{R(t) - R(T)}{|T| - 1}

ここで、

- R(t) は枝刈り前の誤差(元の部分木の誤差) - R(T) は枝刈り後の誤差(葉ノードに置き換えた後の誤差) - |T| は葉ノードの数です

この式の概念的な意味は以下の通りです。

1. 分子 (R(t) - R(T)) は、枝刈りによる誤差の増加を表します。 2. 分母 (|T| - 1) は、枝刈りによる木の複雑さの減少(削除された葉ノードの数)を表します。 3. \alpha は、木の複雑さの単位減少あたりの誤差増加率を表します。

フルスクラッチ実装

決定木のフルスクラッチ実装をやっていこうと思います。各部分をきちんと説明しようかと考えていましたが、このままいくとしないかもです...

import numpy as np
from collections import Counter

class DecisionTreeClassifier:
    def __init__(self, max_depth=None, min_samples_split=2):
        # max_depth: 木の最大深さ(過学習防止のため)
        # min_samples_split: ノードを分割するための最小サンプル数
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.tree = None

    class Node:
        # 決定木のノードを表すクラス
        def __init__(self, feature=None, threshold=None, left=None, right=None, value=None):
            self.feature = feature  # 分割に使用する特徴量のインデックス
            self.threshold = threshold  # 分割の閾値
            self.left = left  # 左の子ノード
            self.right = right  # 右の子ノード
            self.value = value  # 葉ノードの場合の予測値

    def fit(self, X, y):
        # トレーニングデータを用いて決定木を構築
        self.n_features = X.shape[1]
        self.tree = self._grow_tree(X, y)

    def _grow_tree(self, X, y, depth=0):
        n_samples, n_features = X.shape
        n_classes = len(np.unique(y))

        # 停止条件をチェック
        if (self.max_depth is not None and depth >= self.max_depth) or \
           n_samples < self.min_samples_split or \
           n_classes == 1:
            leaf_value = self._most_common_label(y)
            return self.Node(value=leaf_value)

        # 最適な分割を見つける
        best_feature, best_threshold = self._best_split(X, y)

        # データを分割
        left_idxs = X[:, best_feature] < best_threshold
        right_idxs = ~left_idxs

        # 子ノードを再帰的に構築
        left_subtree = self._grow_tree(X[left_idxs], y[left_idxs], depth + 1)
        right_subtree = self._grow_tree(X[right_idxs], y[right_idxs], depth + 1)

        return self.Node(best_feature, best_threshold, left_subtree, right_subtree)

    def _best_split(self, X, y):
        # 最適な分割を見つけるメソッド
        best_gain = -1
        best_feature, best_threshold = None, None

        for feature in range(self.n_features):
            thresholds = np.unique(X[:, feature])
            for threshold in thresholds:
                # 情報利得を計算
                # IG(T, a) = I(T) - Σ((|T_v| / |T|) * I(T_v))
                gain = self._information_gain(X[:, feature], y, threshold)
                if gain > best_gain:
                    best_gain = gain
                    best_feature = feature
                    best_threshold = threshold

        return best_feature, best_threshold

    def _information_gain(self, X_column, y, threshold):
        # 情報利得を計算するメソッド
        parent_entropy = self._entropy(y)

        left_idxs = X_column < threshold
        right_idxs = ~left_idxs

        if len(y[left_idxs]) == 0 or len(y[right_idxs]) == 0:
            return 0

        n = len(y)
        n_left, n_right = len(y[left_idxs]), len(y[right_idxs])
        e_left, e_right = self._entropy(y[left_idxs]), self._entropy(y[right_idxs])
        child_entropy = (n_left / n) * e_left + (n_right / n) * e_right

        # 情報利得 = 親ノードのエントロピー - 子ノードの重み付きエントロピーの和
        return parent_entropy - child_entropy

    def _entropy(self, y):
        # エントロピーを計算するメソッド
        # H(T) = -Σ(p_i * log2(p_i))
        hist = np.bincount(y)
        ps = hist / len(y)
        return -np.sum([p * np.log2(p) for p in ps if p > 0])

    def _most_common_label(self, y):
        # 最も頻繁に出現するラベルを返すメソッド
        counter = Counter(y)
        return counter.most_common(1)[0][0]

    def predict(self, X):
        # 新しいデータポイントに対して予測を行うメソッド
        return np.array([self._traverse_tree(x, self.tree) for x in X])

    def _traverse_tree(self, x, node):
        # 決定木を走査して予測を行うメソッド
        if node.value is not None:
            return node.value

        if x[node.feature] < node.threshold:
            return self._traverse_tree(x, node.left)
        return self._traverse_tree(x, node.right)

このDecisionTreeClassifierクラスを使用して、以下の様にモデル構築し、予測を行うことができます。

# モデルのインスタンス化
clf = DecisionTreeClassifier(max_depth=5, min_samples_split=2)

# モデルの学習
clf.fit(X_train, y_train)

# 予測
y_pred = clf.predict(X_test)

# 精度の評価(例:正解率の計算)
accuracy = np.mean(y_pred == y_test)
print(f"Accuracy: {accuracy:.4f}")
  1. エントロピー_entropyメソッド)
  2. 情報利得(_information_gainメソッド)
  3. 最適分割の選択(_best_splitメソッド)
  4. 再帰的な木の成長(_grow_tree メソッド)
  5. 停止条件:
  6. 予測(predict_traverse_tree メソッド)
      
    今回の実装では、基本的な決定木分類器で実装しているため、枝刈りなどの機能は含んでいません。

    最後に

    結局各部分の詳細な説明はせずに終わってしまいましたが、決定木の理論とそのフルスクラッチ実装はできたかなと思います。 フルスクラッチ実装をなぜ入れたかというと、初心者の方は、数式を読んでも理解できない部分が多いと思ったため、コードと数式を見比べながら理解していくのがベストだと思ったためです。
    あと本当にこの記事で言いたかったことは一つで、注意の部分です。この部分がどうも実際に分析している人で意識できていない人が多く、誤った示唆を提供している場合を多く目撃したためです。
    初心者に優しい記事になったかどうかは、わかりませんが理解の一助となれば幸いです。