O que há no XGBoost e o que o Go tem a ver com isso?

No mundo do aprendizado de máquina, um dos tipos mais populares de modelos é a árvore e os conjuntos decisivos baseados neles. As vantagens das árvores são: facilidade de interpretação, não há restrições quanto ao tipo de dependência inicial, requisitos mínimos para o tamanho da amostra. As árvores também têm uma falha importante - a tendência de reciclagem. Portanto, quase sempre as árvores são combinadas em conjuntos: floresta aleatória, aumento de gradiente, etc. Tarefas teóricas e práticas complexas são compor árvores e combiná-las em conjuntos.

No mesmo artigo, consideraremos o procedimento para gerar previsões a partir de um modelo de conjunto de árvores já treinado, recursos de implementação nas populares XGBoost aumento de gradiente XGBoost e LightGBM . Além disso, o leitor se familiarizará com a biblioteca de leaves do Go, que permite fazer previsões para conjuntos de árvores sem usar a API C das bibliotecas originais.

De onde crescem as árvores?


Considere primeiro as disposições gerais. Eles geralmente trabalham com árvores, onde:

  1. partição em um nó ocorre de acordo com um recurso
  2. árvore binária - cada nó tem um descendente esquerdo e direito
  3. no caso de um atributo material, a regra de decisão consiste em comparar o valor do atributo com um valor limite

Tirei esta ilustração da documentação do XGBoost



Nesta árvore, temos 2 nós, 2 regras de decisão e 3 folhas. Sob os círculos, os valores são indicados - o resultado da aplicação da árvore em algum objeto. Geralmente, uma função de transformação é aplicada ao resultado da computação de uma árvore ou conjunto de árvores. Por exemplo, um sigmóide para um problema de classificação binária.

Para obter previsões do conjunto de árvores obtidas pelo aumento de gradiente, é necessário adicionar os resultados das previsões de todas as árvores:

 double pred = 0.0; for (auto& tree: trees) pred += tree->Predict(feature_values); 

A seguir, haverá C++ , como é nessa linguagem que o XGBoost e o LightGBM são escritos. Omitirei detalhes irrelevantes e tentarei fornecer o código mais conciso.

Em seguida, considere o que está oculto no Predict e como a estrutura de dados da árvore está estruturada.

Árvores do XGBoost


XGBoost possui várias classes (no sentido de POO) de árvores. Falaremos sobre o RegTree (consulte include/xgboost/tree_model.h ), que, de acordo com a documentação, é o principal. Se você deixar apenas os detalhes importantes para as previsões, os membros da classe parecerão o mais simples possível:

 class RegTree { // vector of nodes std::vector<Node> nodes_; }; 

A regra de GetNext é implementada na função GetNext . O código é ligeiramente modificado, sem afetar o resultado dos cálculos:

 // get next position of the tree given current pid int RegTree::GetNext(int pid, float fvalue, bool is_unknown) const { const auto& node = nodes_[pid] float split_value = node.info_.split_cond; if (is_unknown) { return node.DefaultLeft() ? node.cleft_ : node.cright_; } else { if (fvalue < split_value) { return node.cleft_; } else { return node.cright_; } } } 

Duas coisas seguem daqui:

  1. RegTree funciona apenas com atributos reais (tipo float )
  2. valores de característica ignorados são suportados

A peça central é a classe Node . Ele contém a estrutura local da árvore, a regra de decisão e o valor da planilha:

 class Node { public: // feature index of split condition unsigned SplitIndex() const { return sindex_ & ((1U << 31) - 1U); } // when feature is unknown, whether goes to left child bool DefaultLeft() const { return (sindex_ >> 31) != 0; } // whether current node is leaf node bool IsLeaf() const { return cleft_ == -1; } private: // in leaf node, we have weights, in non-leaf nodes, we have split condition union Info { float leaf_value; float split_cond; } info_; // pointer to left, right int cleft_, cright_; // split feature index, left split or right split depends on the highest bit unsigned sindex_{0}; }; 

Os seguintes recursos podem ser distinguidos:

  1. folhas são representadas como nós para os quais cleft_ = -1
  2. o campo info_ representado como union , ou seja, dois tipos de dados (nesse caso, o mesmo) compartilham uma parte da memória, dependendo do tipo de nó
  3. o bit mais significativo em sindex_ é responsável por onde o objeto cujo valor de atributo é ignorado

Para poder rastrear o caminho, chamando o método RegTree::Predict para receber a resposta, darei as duas funções ausentes:

 float RegTree::Predict(const RegTree::FVec& feat, unsigned root_id) const { int pid = this->GetLeafIndex(feat, root_id); return nodes_[pid].leaf_value; } int RegTree::GetLeafIndex(const RegTree::FVec& feat, unsigned root_id) const { auto pid = static_cast<int>(root_id); while (!nodes_[pid].IsLeaf()) { unsigned split_index = nodes_[pid].SplitIndex(); pid = this->GetNext(pid, feat.Fvalue(split_index), feat.IsMissing(split_index)); } return pid; } 

Na função GetLeafIndex descemos os nós da árvore em um loop até GetLeafIndex a folha.

Árvores LightGBM


O LightGBM não possui uma estrutura de dados para o nó. Em vez disso, a estrutura de dados da árvore em Tree (arquivo include/LightGBM/tree.h ) contém matrizes de valores, em que o número do nó é usado como um índice. Valores em folhas também são armazenados em matrizes separadas.

 class Tree { // Number of current leaves int num_leaves_; // A non-leaf node's left child std::vector<int> left_child_; // A non-leaf node's right child std::vector<int> right_child_; // A non-leaf node's split feature, the original index std::vector<int> split_feature_; //A non-leaf node's split threshold in feature value std::vector<double> threshold_; std::vector<int> cat_boundaries_; std::vector<uint32_t> cat_threshold_; // Store the information for categorical feature handle and mising value handle. std::vector<int8_t> decision_type_; // Output of leaves std::vector<double> leaf_value_; }; 

LightGBM suporta recursos categóricos. O suporte é fornecido usando um campo de bit, armazenado em cat_threshold_ para todos os nós. Em cat_boundaries_ armazena em qual nó qual parte do campo de bit corresponde. O campo threshold_ para o caso categórico é convertido em int e corresponde ao índice em cat_boundaries_ para procurar o início do campo de bit.

Considere a regra decisiva para um atributo categórico:

 int CategoricalDecision(double fval, int node) const { uint8_t missing_type = GetMissingType(decision_type_[node]); int int_fval = static_cast<int>(fval); if (int_fval < 0) { return right_child_[node];; } else if (std::isnan(fval)) { // NaN is always in the right if (missing_type == 2) { return right_child_[node]; } int_fval = 0; } int cat_idx = static_cast<int>(threshold_[node]); if (FindInBitset(cat_threshold_.data() + cat_boundaries_[cat_idx], cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx], int_fval)) { return left_child_[node]; } return right_child_[node]; } 

Pode-se observar que, dependendo do tipo de missing_type valor NaN abaixa automaticamente a solução ao longo do ramo direito da árvore. Caso contrário, o NaN será substituído por 0. Procurar um valor em um campo de bits é bastante simples:

 bool FindInBitset(const uint32_t* bits, int n, int pos) { int i1 = pos / 32; if (i1 >= n) { return false; } int i2 = pos % 32; return (bits[i1] >> i2) & 1; } 

ou seja, por exemplo, para o atributo categórico int_fval=42 verificado se o 41º bit (numeração de 0) está definido na matriz.

Essa abordagem tem uma desvantagem significativa: se um atributo categórico pode aceitar valores grandes, por exemplo 100500, para cada regra de decisão para esse atributo, um campo de bit será criado com tamanho de até 12564 bytes!

Portanto, é desejável renumerar os valores dos atributos categóricos para que eles passem continuamente de 0 ao valor máximo .

Pela minha parte, fiz alterações explicativas no LightGBM e as aceitei .

Lidar com atributos físicos não é muito diferente do XGBoost , e vou pular isso por uma questão de brevidade.

leaves - biblioteca para previsões em Go


XGBoost e LightGBM bibliotecas muito poderosas para a construção de modelos de LightGBM gradiente em árvores de decisão. Para usá-los em um serviço de back-end, onde são necessários algoritmos de aprendizado de máquina, é necessário resolver as seguintes tarefas:

  1. Treinamento periódico de modelos offline
  2. Entrega de modelos no serviço de back-end
  3. Modelos de pesquisa online

Para escrever um serviço de back-end carregado, o Go é um idioma popular. XGBoost ou LightGBM pela API C e cgo não é a solução mais fácil - a compilação do programa é complicada, devido ao manuseio descuidado, você pode pegar o SIGTERM , problemas com o número de threads do sistema (OpenMP dentro de bibliotecas versus threads de tempo de execução).

Por isso, decidi escrever uma biblioteca em puro Go para previsões usando modelos criados no XGBoost ou no LightGBM . É chamado de leaves .

folhas

Principais recursos da biblioteca:

  • Para modelos LightGBM
    • Lendo modelos de um formato padrão (texto)
    • Suporte para atributos físicos e categóricos
    • Suporte a valores ausentes
    • Otimização do trabalho com variáveis ​​categóricas
    • Otimização de previsão com estruturas de dados somente de previsão

  • Para modelos XGBoost
    • Lendo modelos de um formato padrão (binário)
    • Suporte a valores ausentes
    • Otimização de previsão


Aqui está um programa Go mínimo que carrega um modelo do disco e exibe uma previsão:

 package main import ( "bufio" "fmt" "os" "github.com/dmitryikh/leaves" ) func main() { // 1.     path := "lightgbm_model.txt" reader, err := os.Open(path) if err != nil { panic(err) } defer reader.Close() // 2.   LightGBM model, err := leaves.LGEnsembleFromReader(bufio.NewReader(reader)) if err != nil { panic(err) } // 3.  ! fvals := []float64{1.0, 2.0, 3.0} p := model.Predict(fvals, 0) fmt.Printf("Prediction for %v: %f\n", fvals, p) } 

A API da biblioteca é mínima. Para usar o modelo XGBoost basta chamar o método leaves.XGEnsembleFromReader vez do método acima. As previsões podem ser feitas em lotes chamando os model.PredictCSR ou model.PredictCSR . Mais cenários de uso podem ser encontrados nos testes de folhas .

Apesar do Go mais lento que o C++ (principalmente devido a verificações mais pesadas de tempo de execução e tempo de execução), graças a várias otimizações, foi possível obter uma taxa de previsão comparável à chamada da API C das bibliotecas originais.


Mais detalhes sobre os resultados e o método de comparação estão no repositório no github .

Veja a raiz


Espero que este artigo abra as portas para a implementação de árvores nas LightGBM e LightGBM . Como você pode ver, as construções básicas são bastante simples, e eu encorajo os leitores a aproveitar o código-fonte aberto - a estudar o código quando houver dúvidas sobre como ele funciona.

Para aqueles interessados ​​no tópico de uso de modelos de aumento de gradiente em seus serviços no idioma Go, recomendo que você se familiarize com a biblioteca de folhas . Usando leaves você pode facilmente usar soluções de ponta no aprendizado de máquina em seu ambiente de produção, quase sem perder velocidade em comparação com as implementações C ++ originais.

Boa sorte

Source: https://habr.com/ru/post/pt423495/


All Articles