这是split的代码:
def split(X, y, d, value):
index_a = (X[:,d] <= value)
index_b = (X[:,d] > value)
return X[index_a], X[index_b], y[index_a], y[index_b]
关于split的理解:
假设y里面有五个种类,我们做的就是把这5个种类分成1个种类和4个种类的前提下,尽量使这个“1”纯净。
这是try_split的代码:
def try_split(X, y):
best_entropy = float('inf')
best_d, best_v = -1, -1
for d in range(X.shape[1]):
sorted_index = np.argsort(X[:,d])
for i in range(1, len(X)):
if X[sorted_index[i], d] != X[sorted_index[i-1], d]:
v = (X[sorted_index[i], d] + X[sorted_index[i-1], d])/2
X_l, X_r, y_l, y_r = split(X, y, d, v)
e = entropy(y_l) + entropy(y_r)
if e < best_entropy:
best_entropy, best_d, best_v = e, d, v
return best_entropy, best_d, best_v
关于try_split的理解:
简单来说,就假设y里只有2种元素。在try_split过程中,进行索引排序后,如果对应的y确实是按照由弱到强,或者由强到弱分布的,那找到的这个v确实是合适的,可假如经过索引排序后,对应的y是由强到弱再到强分布的,那用这个方法找到的v其实是不合适的,也就是说这个方法并不合适。
不知道这两个理解对不对?
登录后可查看更多问答,登录/注册