Was ist in XGBoost enthalten und was hat Go damit zu tun?

In der Welt des maschinellen Lernens ist einer der beliebtesten Modelltypen der entscheidende Baum und die darauf basierenden Ensembles. Die Vorteile von Bäumen sind: einfache Interpretation, es gibt keine Einschränkungen hinsichtlich der Art der anfänglichen Abhängigkeit, weiche Anforderungen an die Größe der Stichprobe. Bäume haben auch einen großen Fehler - die Tendenz zur Umschulung. Daher werden Bäume fast immer zu Ensembles zusammengefasst: zufälliger Wald, Gradientenverstärkung usw. Komplexe theoretische und praktische Aufgaben bestehen darin, Bäume zusammenzusetzen und zu Ensembles zu kombinieren.

Im selben Artikel werden wir das Verfahren zum Generieren von Vorhersagen aus einem bereits trainierten XGBoost Modell sowie Implementierungsfunktionen in den beliebten Gradienten-Boosting- XGBoost und LightGBM . Außerdem wird der Leser mit der leaves Bibliothek für Go vertraut gemacht, mit der Sie Vorhersagen für Baumensembles treffen können, ohne die C-API der Originalbibliotheken zu verwenden.

Woher wachsen die Bäume?


Betrachten Sie zunächst die allgemeinen Bestimmungen. Sie arbeiten normalerweise mit Bäumen, wo:

  1. Die Partition in einem Knoten erfolgt gemäß einer Funktion
  2. Binärbaum - Jeder Knoten hat einen linken und einen rechten Nachkommen
  3. Bei einem wesentlichen Attribut besteht die Entscheidungsregel darin, den Wert des Attributs mit einem Schwellenwert zu vergleichen

Ich habe diese Abbildung aus der XGBoost-Dokumentation entnommen



In diesem Baum haben wir 2 Knoten, 2 Entscheidungsregeln und 3 Blätter. Unter den Kreisen werden die Werte angezeigt - das Ergebnis der Anwendung des Baums auf ein Objekt. Normalerweise wird eine Transformationsfunktion auf das Ergebnis der Berechnung eines Baums oder Baumensembles angewendet. Zum Beispiel ein Sigmoid für ein binäres Klassifizierungsproblem.

Um Vorhersagen aus dem Ensemble von Bäumen zu erhalten, die durch Gradientenverstärkung erhalten wurden, müssen Sie die Ergebnisse der Vorhersagen aller Bäume hinzufügen:

 double pred = 0.0; for (auto& tree: trees) pred += tree->Predict(feature_values); 

Im Folgenden wird es C++ , as In dieser Sprache sind XGBoost und LightGBM geschrieben. Ich werde irrelevante Details weglassen und versuchen, den prägnantesten Code zu geben.

Betrachten Sie als Nächstes, was in Predict verborgen ist und wie die Datenstruktur des Baums strukturiert ist.

XGBoost-Bäume


XGBoost hat mehrere Klassen (im Sinne von OOP) von Bäumen. Wir werden über RegTree sprechen (siehe include/xgboost/tree_model.h ), das laut Dokumentation das wichtigste ist. Wenn Sie nur die für Vorhersagen wichtigen Details belassen, sehen die Mitglieder der Klasse so einfach wie möglich aus:

 class RegTree { // vector of nodes std::vector<Node> nodes_; }; 

Die GetNext ist in der GetNext Funktion implementiert. Der Code wird geringfügig geändert, ohne das Ergebnis der Berechnungen zu beeinflussen:

 // get next position of the tree given current pid int RegTree::GetNext(int pid, float fvalue, bool is_unknown) const { const auto& node = nodes_[pid] float split_value = node.info_.split_cond; if (is_unknown) { return node.DefaultLeft() ? node.cleft_ : node.cright_; } else { if (fvalue < split_value) { return node.cleft_; } else { return node.cright_; } } } 

Zwei Dinge folgen von hier:

  1. RegTree funktioniert nur mit realen Attributen (Typ float )
  2. Übersprungene Kennwerte werden unterstützt

Das Herzstück ist die Node Klasse. Es enthält die lokale Struktur des Baums, die Entscheidungsregel und den Wert des Blattes:

 class Node { public: // feature index of split condition unsigned SplitIndex() const { return sindex_ & ((1U << 31) - 1U); } // when feature is unknown, whether goes to left child bool DefaultLeft() const { return (sindex_ >> 31) != 0; } // whether current node is leaf node bool IsLeaf() const { return cleft_ == -1; } private: // in leaf node, we have weights, in non-leaf nodes, we have split condition union Info { float leaf_value; float split_cond; } info_; // pointer to left, right int cleft_, cright_; // split feature index, left split or right split depends on the highest bit unsigned sindex_{0}; }; 

Folgende Merkmale können unterschieden werden:

  1. Blätter werden als Knoten dargestellt, für die cleft_ = -1
  2. das info_ Feld info_ als union , d.h. Je nach Knotentyp teilen sich zwei Datentypen (in diesem Fall derselbe) einen Speicher
  3. Das höchstwertige Bit in sindex_ ist dafür verantwortlich, wo das Objekt, dessen Attributwert übersprungen wird

Um den Pfad vom Aufrufen der RegTree::Predict Methode bis zum Empfang der Antwort verfolgen zu können, werde ich die beiden fehlenden Funktionen RegTree::Predict :

 float RegTree::Predict(const RegTree::FVec& feat, unsigned root_id) const { int pid = this->GetLeafIndex(feat, root_id); return nodes_[pid].leaf_value; } int RegTree::GetLeafIndex(const RegTree::FVec& feat, unsigned root_id) const { auto pid = static_cast<int>(root_id); while (!nodes_[pid].IsLeaf()) { unsigned split_index = nodes_[pid].SplitIndex(); pid = this->GetNext(pid, feat.Fvalue(split_index), feat.IsMissing(split_index)); } return pid; } 

In der GetLeafIndex Funktion GetLeafIndex wir die GetLeafIndex in einer Schleife GetLeafIndex bis wir das Blatt treffen.

LightGBM-Bäume


LightGBM hat keine Datenstruktur für den Knoten. Stattdessen enthält die Datenstruktur des include/LightGBM/tree.h Datei include/LightGBM/tree.h ) Arrays von Werten, wobei die Knotennummer als Index verwendet wird. Werte in Blättern werden auch in separaten Arrays gespeichert.

 class Tree { // Number of current leaves int num_leaves_; // A non-leaf node's left child std::vector<int> left_child_; // A non-leaf node's right child std::vector<int> right_child_; // A non-leaf node's split feature, the original index std::vector<int> split_feature_; //A non-leaf node's split threshold in feature value std::vector<double> threshold_; std::vector<int> cat_boundaries_; std::vector<uint32_t> cat_threshold_; // Store the information for categorical feature handle and mising value handle. std::vector<int8_t> decision_type_; // Output of leaves std::vector<double> leaf_value_; }; 

LightGBM unterstützt kategoriale Funktionen. Die Unterstützung erfolgt über ein Bitfeld, das für alle Knoten in cat_threshold_ gespeichert ist. In cat_boundaries_ welchem ​​Knoten welcher Teil des cat_boundaries_ entspricht. Das threshold_ für den kategorialen Fall wird in int konvertiert und entspricht dem Index in cat_boundaries_ , um nach dem Anfang des cat_boundaries_ zu suchen.

Betrachten Sie die entscheidende Regel für ein kategoriales Attribut:

 int CategoricalDecision(double fval, int node) const { uint8_t missing_type = GetMissingType(decision_type_[node]); int int_fval = static_cast<int>(fval); if (int_fval < 0) { return right_child_[node];; } else if (std::isnan(fval)) { // NaN is always in the right if (missing_type == 2) { return right_child_[node]; } int_fval = 0; } int cat_idx = static_cast<int>(threshold_[node]); if (FindInBitset(cat_threshold_.data() + cat_boundaries_[cat_idx], cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx], int_fval)) { return left_child_[node]; } return right_child_[node]; } 

Es ist ersichtlich, dass missing_type Wert NaN abhängig vom missing_type die Lösung automatisch entlang des rechten Zweigs des Baums senkt. Andernfalls wird NaN durch 0 ersetzt. Die Suche nach einem Wert in einem Bitfeld ist ganz einfach:

 bool FindInBitset(const uint32_t* bits, int n, int pos) { int i1 = pos / 32; if (i1 >= n) { return false; } int i2 = pos % 32; return (bits[i1] >> i2) & 1; } 

zum Beispiel wird für das kategoriale Attribut int_fval=42 geprüft, ob das 41. Bit (Nummerierung von 0) in dem Array gesetzt ist.

Dieser Ansatz hat einen wesentlichen Nachteil: Wenn ein kategoriales Attribut große Werte annehmen kann, z. B. 100500, wird für jede Entscheidungsregel für dieses Attribut ein Bitfeld mit einer Größe von bis zu 12564 Byte erstellt!

Daher ist es wünschenswert, die Werte von kategorialen Attributen so neu zu nummerieren, dass sie kontinuierlich von 0 auf den Maximalwert gehen .

Ich für meinen Teil habe erklärende Änderungen an LightGBM und diese akzeptiert .

Der Umgang mit physischen Attributen unterscheidet sich nicht wesentlich von XGBoost , und ich werde dies der Kürze XGBoost überspringen.

Blätter - Bibliothek für Vorhersagen in Go


XGBoost und LightGBM sehr leistungsfähige Bibliotheken zum LightGBM Gradienten- LightGBM Modellen auf Entscheidungsbäumen. Um sie in einem Backend-Dienst zu verwenden, in dem Algorithmen für maschinelles Lernen benötigt werden, müssen die folgenden Aufgaben gelöst werden:

  1. Regelmäßiges Training von Modellen offline
  2. Lieferung von Modellen im Backend-Service
  3. Modelle online abfragen

Zum Schreiben eines geladenen Backend-Dienstes ist Go eine beliebte Sprache. XGBoost oder LightGBM durch die C-API und cgo ist nicht die einfachste Lösung. Die Erstellung des Programms ist kompliziert. Aufgrund der unachtsamen Behandlung können Sie SIGTERM Probleme mit der Anzahl der Systemthreads erkennen (OpenMP in Bibliotheken vs. Go-Laufzeit-Threads).

Deshalb habe ich beschlossen, eine Bibliothek auf pure Go für Vorhersagen mit in XGBoost oder LightGBM Modellen zu schreiben. Es heißt leaves .

Blätter

Hauptmerkmale der Bibliothek:

  • Für LightGBM Modelle
    • Lesen von Modellen aus einem Standardformat (Text)
    • Unterstützung für physische und kategoriale Attribute
    • Unterstützung für fehlende Werte
    • Optimierung der Arbeit mit kategorialen Variablen
    • Vorhersageoptimierung mit Nur-Vorhersage-Datenstrukturen

  • Für XGBoost Modelle
    • Lesen von Modellen aus einem Standardformat (binär)
    • Unterstützung für fehlende Werte
    • Vorhersageoptimierung


Hier ist ein minimales Go Programm, das ein Modell von der Festplatte lädt und eine Vorhersage anzeigt:

 package main import ( "bufio" "fmt" "os" "github.com/dmitryikh/leaves" ) func main() { // 1.     path := "lightgbm_model.txt" reader, err := os.Open(path) if err != nil { panic(err) } defer reader.Close() // 2.   LightGBM model, err := leaves.LGEnsembleFromReader(bufio.NewReader(reader)) if err != nil { panic(err) } // 3.  ! fvals := []float64{1.0, 2.0, 3.0} p := model.Predict(fvals, 0) fmt.Printf("Prediction for %v: %f\n", fvals, p) } 

Die Bibliotheks-API ist minimal. Um das XGBoost Modell zu verwenden XGBoost rufen leaves.XGEnsembleFromReader einfach die leaves.XGEnsembleFromReader Methode anstelle der obigen auf. Vorhersagen können in PredictDense durch Aufrufen der PredictDense oder model.PredictCSR . Weitere Verwendungsszenarien finden Sie in Blatttests .

Trotz der Tatsache, dass Go langsamer als C++ läuft (hauptsächlich aufgrund intensiverer Laufzeit- und Laufzeitprüfungen), konnte dank einer Reihe von Optimierungen eine Vorhersagerate erzielt werden, die mit dem Aufruf der C-API der ursprünglichen Bibliotheken vergleichbar ist.


Weitere Details zu den Ergebnissen und Vergleichsmethoden finden Sie im Repository auf github .

Siehe die Wurzel


Ich hoffe, dieser Artikel öffnet die Tür zur Implementierung von Bäumen in den XGBoost und LightGBM . Wie Sie sehen können, sind die grundlegenden Konstruktionen recht einfach, und ich ermutige die Leser, Open Source zu nutzen, um den Code zu studieren, wenn Fragen zur Funktionsweise auftreten.

Für diejenigen, die sich für das Thema der Verwendung von Gradientenverstärkungsmodellen in ihren Diensten in der Sprache Go interessieren, empfehle ich, dass Sie sich mit der Blattbibliothek vertraut machen. Mithilfe von leaves Sie ganz einfach führende Lösungen für maschinelles Lernen in Ihrer Produktionsumgebung verwenden, ohne dabei im Vergleich zu den ursprünglichen C ++ - Implementierungen an Geschwindigkeit zu verlieren.

Viel Glück!

Source: https://habr.com/ru/post/de423495/


All Articles