XGBoost内部有什么,Go与它有什么关系?

在机器学习领域,最流行的模型类型之一是决定性的树和基于它们的合奏。 树木的优点是:易于解释,对初始依赖性的类型没有限制,对样本大小的软要求。 树木也有一个主要缺陷-易于重新训练。 因此,几乎总是将树木组合成整体:随机森林,梯度增强等。复杂的理论和实践任务是将树木组成并将其组合成整体。

在同一篇文章中,我们将考虑从已经训练XGBoost树集成模型生成预测的过程,以及流行的梯度提升XGBoostLightGBM中的实现功能。 此外,读者还将熟悉Go的leaves库,这使您无需使用原始库的C API即可预测树木的整体。

树木从哪里生长?


首先考虑一般规定。 它们通常与树木配合使用,其中:

  1. 节点中的分区根据一种功能发生
  2. 二叉树-每个节点都有一个左右后代
  3. 对于重要属性,决策规则包括将属性值与阈值进行比较

我从XGBoost文档中获取了此插图



在这棵树中,我们有2个节点,2个决策规则和3个工作表。 在圆圈下面,指示值-将树应用于某个对象的结果。 通常,将变换函数应用于计算树或树集合的结果。 例如,针对二进制分类问题的S形

要从通过梯度增强获得的树木集合中获得预测,您需要添加所有树木的预测结果:

 double pred = 0.0; for (auto& tree: trees) pred += tree->Predict(feature_values); 

在下文中,将有C++ ,如 XGBoostLightGBM就是用这种语言编写的。 我将忽略无关的细节,并尝试给出最简洁的代码。

接下来,考虑Predict隐藏的内容以及树的数据结构的结构。

XGBoost树


XGBoost具有几种树(在OOP的意义上)。 我们将讨论RegTree (请参阅include/xgboost/tree_model.h ),根据文档,这是主要的。 如果仅保留对预测很重要的细节,则该类的成员看起来将尽可能简单:

 class RegTree { // vector of nodes std::vector<Node> nodes_; }; 

GetNext规则在GetNext函数中实现。 对该代码进行了少许修改,而不会影响计算结果:

 // get next position of the tree given current pid int RegTree::GetNext(int pid, float fvalue, bool is_unknown) const { const auto& node = nodes_[pid] float split_value = node.info_.split_cond; if (is_unknown) { return node.DefaultLeft() ? node.cleft_ : node.cright_; } else { if (fvalue < split_value) { return node.cleft_; } else { return node.cright_; } } } 

这里有两件事:

  1. RegTree仅适用于真实属性(类型float
  2. 支持跳过的特征值

核心是Node类。 它包含树的局部结构,决策规则和工作表的值:

 class Node { public: // feature index of split condition unsigned SplitIndex() const { return sindex_ & ((1U << 31) - 1U); } // when feature is unknown, whether goes to left child bool DefaultLeft() const { return (sindex_ >> 31) != 0; } // whether current node is leaf node bool IsLeaf() const { return cleft_ == -1; } private: // in leaf node, we have weights, in non-leaf nodes, we have split condition union Info { float leaf_value; float split_cond; } info_; // pointer to left, right int cleft_, cright_; // split feature index, left split or right split depends on the highest bit unsigned sindex_{0}; }; 

可以区分以下功能:

  1. 工作表表示为cleft_ = -1节点
  2. info_字段表示为union ,即 两种类型的数据(在这种情况下相同)共享一个内存,具体取决于节点的类型
  3. sindex_的最高有效位负责跳过其属性值的对象的位置

为了能够跟踪从调用RegTree::Predict方法到接收答案的路径,我将提供缺失的两个函数:

 float RegTree::Predict(const RegTree::FVec& feat, unsigned root_id) const { int pid = this->GetLeafIndex(feat, root_id); return nodes_[pid].leaf_value; } int RegTree::GetLeafIndex(const RegTree::FVec& feat, unsigned root_id) const { auto pid = static_cast<int>(root_id); while (!nodes_[pid].IsLeaf()) { unsigned split_index = nodes_[pid].SplitIndex(); pid = this->GetNext(pid, feat.Fvalue(split_index), feat.IsMissing(split_index)); } return pid; } 

GetLeafIndex函数中GetLeafIndex我们循环遍历树节点,直到到达叶子为止。

LightGBM树


LightGBM没有该节点的数据结构。 相反,Tree Tree的数据结构( include/LightGBM/tree.h文件)包含值的数组,其中节点号用作索引。 叶子中的值也存储在单独的数组中。

 class Tree { // Number of current leaves int num_leaves_; // A non-leaf node's left child std::vector<int> left_child_; // A non-leaf node's right child std::vector<int> right_child_; // A non-leaf node's split feature, the original index std::vector<int> split_feature_; //A non-leaf node's split threshold in feature value std::vector<double> threshold_; std::vector<int> cat_boundaries_; std::vector<uint32_t> cat_threshold_; // Store the information for categorical feature handle and mising value handle. std::vector<int8_t> decision_type_; // Output of leaves std::vector<double> leaf_value_; }; 

LightGBM支持分类功能。 使用位字段提供支持,该位字段存储在cat_threshold_用于所有节点。 在cat_boundaries_存储位字段的哪一部分对应于哪个节点。 用于分类情况的threshold_字段将转换为int并与cat_boundaries_的索引相对应,以搜索位字段的开头。

考虑分类属性的决定性规则:

 int CategoricalDecision(double fval, int node) const { uint8_t missing_type = GetMissingType(decision_type_[node]); int int_fval = static_cast<int>(fval); if (int_fval < 0) { return right_child_[node];; } else if (std::isnan(fval)) { // NaN is always in the right if (missing_type == 2) { return right_child_[node]; } int_fval = 0; } int cat_idx = static_cast<int>(threshold_[node]); if (FindInBitset(cat_threshold_.data() + cat_boundaries_[cat_idx], cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx], int_fval)) { return left_child_[node]; } return right_child_[node]; } 

可以看到,根据missing_typeNaN自动降低沿树的右分支的解。 否则,将NaN替换为0。在位字段中搜索值非常简单:

 bool FindInBitset(const uint32_t* bits, int n, int pos) { int i1 = pos / 32; if (i1 >= n) { return false; } int i2 = pos % 32; return (bits[i1] >> i2) & 1; } 

即,例如,对于分类属性int_fval=42检查数组中是否设置了第41位(从0开始编号)。

这种方法有一个明显的缺点:如果分类属性可以采用较大的值,例如100500,则对于该属性的每个决策规则,将创建一个最大为12564字节的位字段!

因此,希望对分类属性的值进行重新编号,以使它们从0连续变为最大值

就我而言,我对LightGBM进行了说明性更改并接受了它们

处理物理属性与XGBoost并没有太大区别,为简洁起见,我将跳过此内容。

leaves-Go中的预测库


XGBoostLightGBM非常强大的库,用于在决策树上构建梯度LightGBM模型。 要在需要机器学习算法的后端服务中使用它们,必须解决以下任务:

  1. 离线定期训练模型
  2. 后端服务中的模型交付
  3. 在线投票模型

对于编写加载的后端服务, Go是一种流行的语言。 通过C API和cgo XGBoostLightGBM并不是最简单的解决方案-程序的构建很复杂,由于粗心的处理,您可能会发现SIGTERM ,这是系统线程数的问题(库中的OpenMP与运行时线程)。

因此,我决定使用XGBoostLightGBM内置的模型在纯Go上编写一个库,以进行预测。 它被称为leaves

叶

该库的主要功能:

  • 对于LightGBM型号
    • 从标准格式(文本)读取模型
    • 支持物理和分类属性
    • 缺失价值支持
    • 使用分类变量优化工作
    • 仅具有预测数据结构的预测优化

  • 对于XGBoost模型
    • 从标准格式(二进制)读取模型
    • 缺失价值支持
    • 预测优化


这是一个最小的Go程序,该程序从磁盘加载模型并显示预测:

 package main import ( "bufio" "fmt" "os" "github.com/dmitryikh/leaves" ) func main() { // 1.     path := "lightgbm_model.txt" reader, err := os.Open(path) if err != nil { panic(err) } defer reader.Close() // 2.   LightGBM model, err := leaves.LGEnsembleFromReader(bufio.NewReader(reader)) if err != nil { panic(err) } // 3.  ! fvals := []float64{1.0, 2.0, 3.0} p := model.Predict(fvals, 0) fmt.Printf("Prediction for %v: %f\n", fvals, p) } 

库API最少。 要使用XGBoost模型XGBoost只需调用leaves.XGEnsembleFromReader方法而不是上面的方法即可。 可以通过调用PredictDensemodel.PredictCSR来批量进行预测。 在叶子测试中可以找到更多的使用场景。

尽管Go运行速度比C++慢(主要是由于运行时和运行时检查工作量大),但由于进行了许多优化,所以可以达到与调用原始库的C API相当的预测速度。


有关结果和比较方法的更多详细信息,请参见github上存储库

见根


我希望本文为XGBoostLightGBM的树实现打开大门。 如您所见,基本结构非常简单,我鼓励读者利用开放源代码-在对代码的工作方式有疑问时研究代码。

对于那些对使用Go语言在其服务中使用梯度增强模型感兴趣的人,我建议您熟悉叶子库。 使用leaves您可以在生产环境中的机器学习中轻松使用领先的解决方案,而与原始C ++实现相比,几乎不会损失速度。

祝你好运!

Source: https://habr.com/ru/post/zh-CN423495/


All Articles