Sie haben also Ihr neuronales Netzwerk entwickelt und trainiert, um eine bestimmte Aufgabe auszuführen (z. B. dieselbe Objekterkennung über die Kamera), und möchten es in Ihre Android-Anwendung integrieren? Dann willkommen bei kat!
Zunächst sollte klar sein, dass der Android derzeit nur mit Netzwerken im TensorFlowLite-Format arbeiten kann, was bedeutet, dass wir einige Manipulationen mit dem Quellnetzwerk durchführen müssen. Angenommen, Sie haben ein bereits geschultes Netzwerk im Keras- oder Tensorflow-Framework. Sie müssen das Raster im pb-Format speichern.
Beginnen wir mit dem Fall, wenn Sie auf Tensorflow schreiben, dann ist alles etwas einfacher.
saver = tf.train.Saver() tf.train.write_graph(session.graph_def, path_to_folder, "net.pb", False) tf.train.write_graph(session.graph_def, path_to_folder, "net.pbtxt", True) saver.save(session,path_to_folder+"model.ckpt")
Wenn Sie in Keras schreiben, müssen Sie ein neues Sitzungsobjekt erstellen, einen Link dazu am Anfang der Datei speichern, in der Sie das Netzwerk trainieren, und es an die Funktion set_session übergeben
import keras.backend as K session = K.get_session() K.set_session(session)
Nun, Sie haben das Netzwerk gespeichert, jetzt müssen Sie es in das tflite-Format konvertieren. Dazu müssen wir zwei kleine Skripte ausführen, das erste „friert“ das Netzwerk ein, das zweite wird bereits in das gewünschte Format übersetzt. Das Wesentliche beim „Einfrieren“ ist, dass tf das Gewicht der Ebenen nicht in der gespeicherten pb-Datei speichert, sondern an speziellen Prüfpunkten speichert. Für die anschließende Konvertierung in tflite benötigen Sie alle Informationen zum neuronalen Netzwerk in einer Datei.
freeze_graph --input_binary=false --input_graph=net.pbtxt --output_node_names=result/Softmax --output_graph=frozen_graph.pb --input_checkpoint=model.ckpt
Beachten Sie, dass Sie den Namen des Ausgangstensors kennen müssen. In Tensorflow können Sie es selbst festlegen, wenn Sie Keras verwenden - legen Sie den Namen im Ebenenkonstruktor fest
model.add(Dense(10,activation="softmax",name="result"))
In diesem Fall sieht der Name des Tensors normalerweise wie folgt aus: "Ergebnis / Softmax"
Wenn dies in Ihrem Fall nicht der Fall ist, können Sie den Namen wie folgt finden
[print(n.name) for n in session.graph.as_graph_def().node]
Es bleibt das zweite Skript auszuführen
toco --graph_def_file=frozen-graph.pb --output_file=model.tflite --output_format=TFLITE --inference_type=FLOAT --input_arrays=input_input --output_arrays=result/Softmax --input_shapes=1,784
Hurra! Jetzt haben Sie ein TensorFlowLite-Modell in Ihrem Ordner. Es liegt an Ihnen, es korrekt in Ihre Android-Anwendung zu integrieren. Sie können dies mit dem neuen Firebase ML Kit tun, aber es gibt einen anderen Weg, etwas später. Fügen Sie unserer Gradle-Datei eine Abhängigkeit hinzu
dependencies { // ... implementation 'com.google.firebase:firebase-ml-model-interpreter:16.2.0' }
Jetzt müssen Sie entscheiden, ob Sie das Modell irgendwo auf Ihrem Server behalten oder mit der Anwendung versenden möchten.
Betrachten Sie den ersten Fall: ein Modell auf dem Server. Vergessen Sie zunächst nicht, das Manifest zu ergänzen
<uses-permission android:name="android.permission.INTERNET" />
Wenn Sie das in der Anwendung enthaltene Modell lokal verwenden, vergessen Sie nicht, der Datei build.gradle den folgenden Eintrag hinzuzufügen, damit die Modelldatei nicht komprimiert wird
android { // ... aaptOptions { noCompress "tflite" } }
Danach muss analog zum Modell in der Cloud unser lokales Neuron registriert werden.
FirebaseLocalModelSource localSource = new FirebaseLocalModelSource.Builder("my_local_model") .setAssetFilePath("mymodel.tflite") .build(); FirebaseModelManager.getInstance().registerLocalModelSource(localSource);
Der obige Code setzt voraus, dass sich Ihr Modell im Assets-Ordner befindet, wenn nicht, stattdessen
.setAssetFilePath("mymodel.tflite")
verwenden
.seFilePath(filePath)
Dann erstellen wir neue Objekte FirebaseModelOptions und FirebaseModelInterpreter
FirebaseModelOptions options = new FirebaseModelOptions.Builder() .setCloudModelName("my_cloud_model") .setLocalModelName("my_local_model") .build(); FirebaseModelInterpreter firebaseInterpreter = FirebaseModelInterpreter.getInstance(options);
Sie können sowohl lokale als auch serverbasierte Modelle gleichzeitig verwenden. In diesem Fall wird die Cloud standardmäßig verwendet, sofern verfügbar, andernfalls lokal.
Fast alles bleibt, um Arrays für Eingabe- / Ausgabedaten zu erstellen und auszuführen!
FirebaseModelInputOutputOptions inputOutputOptions = new FirebaseModelInputOutputOptions.Builder() .setInputFormat(0, FirebaseModelDataType.BYTE, new int[]{1, 640, 480, 3}) .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 784}) .build(); byte[][][][] input = new byte[1][640][480][3]; input = getYourInputData(); FirebaseModelInputs inputs = new FirebaseModelInputs.Builder() .add(input)
Wenn Sie Firebase aus irgendeinem Grund nicht verwenden möchten, gibt es eine andere Möglichkeit: Rufen Sie den tflite-Interpreter auf und geben Sie die Daten direkt ein.
Fügen Sie eine Linie zum Erstellen / Gradle hinzu
implementation 'org.tensorflow:tensorflow-lite:+'
Erstellen Sie einen Interpreter und Arrays
Interpreter tflite = new Interpreter(loadModelFile(getContext(), "model.tflite"));
Der Code ist in diesem Fall viel weniger, wie Sie sehen können.
Das ist alles, was Sie brauchen, um Ihr neuronales Netzwerk in Android zu verwenden.
Nützliche Links:
Off Docks von ML KitTensorflow Lite