Réseaux de neurones dans Android, Google ML Kit et pas seulement

Vous avez donc développé et formé votre réseau de neurones pour effectuer une sorte de tâche (par exemple, la même reconnaissance d'objet via la caméra) et vous souhaitez l'implémenter dans votre application Android? Alors bienvenue à Kat!

Pour commencer, il faut comprendre que l'androïde ne sait actuellement que travailler avec les réseaux au format TensorFlowLite, ce qui signifie que nous devons effectuer certaines manipulations avec le réseau source. Supposons que vous disposiez d'un réseau déjà formé sur l'infrastructure Keras ou Tensorflow. Vous devez enregistrer la grille au format pb.

Commençons par le cas lorsque vous écrivez sur Tensorflow, alors tout est un peu plus facile.

saver = tf.train.Saver() tf.train.write_graph(session.graph_def, path_to_folder, "net.pb", False) tf.train.write_graph(session.graph_def, path_to_folder, "net.pbtxt", True) saver.save(session,path_to_folder+"model.ckpt") 

Si vous écrivez dans Keras, vous devez créer un nouvel objet de session, enregistrer un lien vers celui-ci au début du fichier dans lequel vous entraînez le réseau et le transmettre à la fonction set_session

 import keras.backend as K session = K.get_session() K.set_session(session) 

Eh bien, vous avez enregistré le réseau, vous devez maintenant le convertir au format tflite. Pour ce faire, nous devons exécuter deux petits scripts, le premier «fige» le réseau, le second se traduira déjà au format souhaité. L'essence du «gel» est que tf ne stocke pas le poids des couches dans le fichier pb enregistré, mais les enregistre dans des points de contrôle spéciaux. Pour une conversion ultérieure en tflite, vous avez besoin de toutes les informations sur le réseau neuronal dans un seul fichier.

 freeze_graph --input_binary=false --input_graph=net.pbtxt --output_node_names=result/Softmax --output_graph=frozen_graph.pb --input_checkpoint=model.ckpt 

Notez que vous devez connaître le nom du tenseur de sortie. Dans tensorflow, vous pouvez le définir vous-même, en cas d'utilisation de Keras - définissez le nom dans le constructeur de couche

 model.add(Dense(10,activation="softmax",name="result")) 

Dans ce cas, le nom du tenseur ressemble généralement à «résultat / Softmax»

Si dans votre cas, ce n'est pas le cas, vous pouvez trouver le nom comme suit

 [print(n.name) for n in session.graph.as_graph_def().node] 

Il reste à exécuter le deuxième script

 toco --graph_def_file=frozen-graph.pb --output_file=model.tflite --output_format=TFLITE --inference_type=FLOAT --input_arrays=input_input --output_arrays=result/Softmax --input_shapes=1,784 

Hourra! Maintenant que vous avez un modèle TensorFlowLite dans votre dossier, c'est à vous de l'intégrer correctement dans votre application Android. Vous pouvez le faire avec le nouveau kit Firebase ML, mais il existe un autre moyen, à ce sujet un peu plus tard. Ajoutez une dépendance à notre fichier gradle

 dependencies { // ... implementation 'com.google.firebase:firebase-ml-model-interpreter:16.2.0' } 

Vous devez maintenant décider si vous conserverez le modèle quelque part sur votre serveur ou si vous le livrez avec l'application.

Prenons le premier cas: un modèle sur le serveur. Tout d'abord, n'oubliez pas d'ajouter au manifeste

 <uses-permission android:name="android.permission.INTERNET" /> 

  //      ,   /  FirebaseModelDownloadConditions.Builder conditionsBuilder = new FirebaseModelDownloadConditions.Builder().requireWifi(); if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) { conditionsBuilder = conditionsBuilder .requireCharging(); } FirebaseModelDownloadConditions conditions = conditionsBuilder.build(); //   FirebaseCloudModelSource ,   (    ,  //   Firebase) FirebaseCloudModelSource cloudSource = new FirebaseCloudModelSource.Builder("my_cloud_model") .enableModelUpdates(true) .setInitialDownloadConditions(conditions) .setUpdatesDownloadConditions(conditions) .build(); FirebaseModelManager.getInstance().registerCloudModelSource(cloudSource); 

Si vous utilisez localement le modèle inclus dans l'application, n'oubliez pas d'ajouter l'entrée suivante au fichier build.gradle pour que le fichier modèle ne soit pas compressé

 android { // ... aaptOptions { noCompress "tflite" } } 

Après quoi, par analogie avec le modèle dans le nuage, notre neurone local doit être enregistré.

 FirebaseLocalModelSource localSource = new FirebaseLocalModelSource.Builder("my_local_model") .setAssetFilePath("mymodel.tflite") .build(); FirebaseModelManager.getInstance().registerLocalModelSource(localSource); 

Le code ci-dessus suppose que votre modèle se trouve dans le dossier des ressources, sinon, à la place

  .setAssetFilePath("mymodel.tflite") 

utiliser

  .seFilePath(filePath) 

Ensuite, nous créons de nouveaux objets FirebaseModelOptions et FirebaseModelInterpreter

 FirebaseModelOptions options = new FirebaseModelOptions.Builder() .setCloudModelName("my_cloud_model") .setLocalModelName("my_local_model") .build(); FirebaseModelInterpreter firebaseInterpreter = FirebaseModelInterpreter.getInstance(options); 

Vous pouvez utiliser à la fois des modèles locaux et basés sur un serveur. Dans ce cas, le cloud sera utilisé par défaut, s'il est disponible, sinon local.

Presque tout, il reste à créer des tableaux pour les données d'entrée / sortie, et à exécuter!

 FirebaseModelInputOutputOptions inputOutputOptions = new FirebaseModelInputOutputOptions.Builder() .setInputFormat(0, FirebaseModelDataType.BYTE, new int[]{1, 640, 480, 3}) .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 784}) .build(); byte[][][][] input = new byte[1][640][480][3]; input = getYourInputData(); FirebaseModelInputs inputs = new FirebaseModelInputs.Builder() .add(input) // add() as many input arrays as your model requires .build(); Task<FirebaseModelOutputs> result = firebaseInterpreter.run(inputs, inputOutputOptions) .addOnSuccessListener( new OnSuccessListener<FirebaseModelOutputs>() { @Override public void onSuccess(FirebaseModelOutputs result) { // ... } }) .addOnFailureListener( new OnFailureListener() { @Override public void onFailure(@NonNull Exception e) { // Task failed with an exception // ... } }); float[][] output = result.<float[][]>getOutput(0); float[] probabilities = output[0]; 

Si, pour une raison quelconque, vous ne souhaitez pas utiliser Firebase, il existe un autre moyen, en appelant l'interpréteur tflite et en lui fournissant directement des données.

Ajouter une ligne à construire / gradle

  implementation 'org.tensorflow:tensorflow-lite:+' 

Créer un interprète et des tableaux

  Interpreter tflite = new Interpreter(loadModelFile(getContext(), "model.tflite")); //     inputs tflite.run(inputs,outputs) 

Le code dans ce cas est beaucoup moins, comme vous pouvez le voir.

C'est tout ce dont vous avez besoin pour utiliser votre réseau de neurones dans Android.

Liens utiles:

Off docks par ML Kit
Tensorflow Lite

Source: https://habr.com/ru/post/fr422041/


All Articles