Neuronale Netze in Android, Google ML Kit und nicht nur

Sie haben also Ihr neuronales Netzwerk entwickelt und trainiert, um eine bestimmte Aufgabe auszuführen (z. B. dieselbe Objekterkennung über die Kamera), und möchten es in Ihre Android-Anwendung integrieren? Dann willkommen bei kat!

Zunächst sollte klar sein, dass der Android derzeit nur mit Netzwerken im TensorFlowLite-Format arbeiten kann, was bedeutet, dass wir einige Manipulationen mit dem Quellnetzwerk durchführen müssen. Angenommen, Sie haben ein bereits geschultes Netzwerk im Keras- oder Tensorflow-Framework. Sie müssen das Raster im pb-Format speichern.

Beginnen wir mit dem Fall, wenn Sie auf Tensorflow schreiben, dann ist alles etwas einfacher.

saver = tf.train.Saver() tf.train.write_graph(session.graph_def, path_to_folder, "net.pb", False) tf.train.write_graph(session.graph_def, path_to_folder, "net.pbtxt", True) saver.save(session,path_to_folder+"model.ckpt") 

Wenn Sie in Keras schreiben, müssen Sie ein neues Sitzungsobjekt erstellen, einen Link dazu am Anfang der Datei speichern, in der Sie das Netzwerk trainieren, und es an die Funktion set_session übergeben

 import keras.backend as K session = K.get_session() K.set_session(session) 

Nun, Sie haben das Netzwerk gespeichert, jetzt müssen Sie es in das tflite-Format konvertieren. Dazu müssen wir zwei kleine Skripte ausführen, das erste „friert“ das Netzwerk ein, das zweite wird bereits in das gewünschte Format übersetzt. Das Wesentliche beim „Einfrieren“ ist, dass tf das Gewicht der Ebenen nicht in der gespeicherten pb-Datei speichert, sondern an speziellen Prüfpunkten speichert. Für die anschließende Konvertierung in tflite benötigen Sie alle Informationen zum neuronalen Netzwerk in einer Datei.

 freeze_graph --input_binary=false --input_graph=net.pbtxt --output_node_names=result/Softmax --output_graph=frozen_graph.pb --input_checkpoint=model.ckpt 

Beachten Sie, dass Sie den Namen des Ausgangstensors kennen müssen. In Tensorflow können Sie es selbst festlegen, wenn Sie Keras verwenden - legen Sie den Namen im Ebenenkonstruktor fest

 model.add(Dense(10,activation="softmax",name="result")) 

In diesem Fall sieht der Name des Tensors normalerweise wie folgt aus: "Ergebnis / Softmax"

Wenn dies in Ihrem Fall nicht der Fall ist, können Sie den Namen wie folgt finden

 [print(n.name) for n in session.graph.as_graph_def().node] 

Es bleibt das zweite Skript auszuführen

 toco --graph_def_file=frozen-graph.pb --output_file=model.tflite --output_format=TFLITE --inference_type=FLOAT --input_arrays=input_input --output_arrays=result/Softmax --input_shapes=1,784 

Hurra! Jetzt haben Sie ein TensorFlowLite-Modell in Ihrem Ordner. Es liegt an Ihnen, es korrekt in Ihre Android-Anwendung zu integrieren. Sie können dies mit dem neuen Firebase ML Kit tun, aber es gibt einen anderen Weg, etwas später. Fügen Sie unserer Gradle-Datei eine Abhängigkeit hinzu

 dependencies { // ... implementation 'com.google.firebase:firebase-ml-model-interpreter:16.2.0' } 

Jetzt müssen Sie entscheiden, ob Sie das Modell irgendwo auf Ihrem Server behalten oder mit der Anwendung versenden möchten.

Betrachten Sie den ersten Fall: ein Modell auf dem Server. Vergessen Sie zunächst nicht, das Manifest zu ergänzen

 <uses-permission android:name="android.permission.INTERNET" /> 

  //      ,   /  FirebaseModelDownloadConditions.Builder conditionsBuilder = new FirebaseModelDownloadConditions.Builder().requireWifi(); if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) { conditionsBuilder = conditionsBuilder .requireCharging(); } FirebaseModelDownloadConditions conditions = conditionsBuilder.build(); //   FirebaseCloudModelSource ,   (    ,  //   Firebase) FirebaseCloudModelSource cloudSource = new FirebaseCloudModelSource.Builder("my_cloud_model") .enableModelUpdates(true) .setInitialDownloadConditions(conditions) .setUpdatesDownloadConditions(conditions) .build(); FirebaseModelManager.getInstance().registerCloudModelSource(cloudSource); 

Wenn Sie das in der Anwendung enthaltene Modell lokal verwenden, vergessen Sie nicht, der Datei build.gradle den folgenden Eintrag hinzuzufügen, damit die Modelldatei nicht komprimiert wird

 android { // ... aaptOptions { noCompress "tflite" } } 

Danach muss analog zum Modell in der Cloud unser lokales Neuron registriert werden.

 FirebaseLocalModelSource localSource = new FirebaseLocalModelSource.Builder("my_local_model") .setAssetFilePath("mymodel.tflite") .build(); FirebaseModelManager.getInstance().registerLocalModelSource(localSource); 

Der obige Code setzt voraus, dass sich Ihr Modell im Assets-Ordner befindet, wenn nicht, stattdessen

  .setAssetFilePath("mymodel.tflite") 

verwenden

  .seFilePath(filePath) 

Dann erstellen wir neue Objekte FirebaseModelOptions und FirebaseModelInterpreter

 FirebaseModelOptions options = new FirebaseModelOptions.Builder() .setCloudModelName("my_cloud_model") .setLocalModelName("my_local_model") .build(); FirebaseModelInterpreter firebaseInterpreter = FirebaseModelInterpreter.getInstance(options); 

Sie können sowohl lokale als auch serverbasierte Modelle gleichzeitig verwenden. In diesem Fall wird die Cloud standardmäßig verwendet, sofern verfügbar, andernfalls lokal.

Fast alles bleibt, um Arrays für Eingabe- / Ausgabedaten zu erstellen und auszuführen!

 FirebaseModelInputOutputOptions inputOutputOptions = new FirebaseModelInputOutputOptions.Builder() .setInputFormat(0, FirebaseModelDataType.BYTE, new int[]{1, 640, 480, 3}) .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 784}) .build(); byte[][][][] input = new byte[1][640][480][3]; input = getYourInputData(); FirebaseModelInputs inputs = new FirebaseModelInputs.Builder() .add(input) // add() as many input arrays as your model requires .build(); Task<FirebaseModelOutputs> result = firebaseInterpreter.run(inputs, inputOutputOptions) .addOnSuccessListener( new OnSuccessListener<FirebaseModelOutputs>() { @Override public void onSuccess(FirebaseModelOutputs result) { // ... } }) .addOnFailureListener( new OnFailureListener() { @Override public void onFailure(@NonNull Exception e) { // Task failed with an exception // ... } }); float[][] output = result.<float[][]>getOutput(0); float[] probabilities = output[0]; 

Wenn Sie Firebase aus irgendeinem Grund nicht verwenden möchten, gibt es eine andere Möglichkeit: Rufen Sie den tflite-Interpreter auf und geben Sie die Daten direkt ein.

Fügen Sie eine Linie zum Erstellen / Gradle hinzu

  implementation 'org.tensorflow:tensorflow-lite:+' 

Erstellen Sie einen Interpreter und Arrays

  Interpreter tflite = new Interpreter(loadModelFile(getContext(), "model.tflite")); //     inputs tflite.run(inputs,outputs) 

Der Code ist in diesem Fall viel weniger, wie Sie sehen können.

Das ist alles, was Sie brauchen, um Ihr neuronales Netzwerk in Android zu verwenden.

Nützliche Links:

Off Docks von ML Kit
Tensorflow Lite

Source: https://habr.com/ru/post/de422041/


All Articles