Vous avez donc développé et formé votre réseau de neurones pour effectuer une sorte de tâche (par exemple, la même reconnaissance d'objet via la caméra) et vous souhaitez l'implémenter dans votre application Android? Alors bienvenue à Kat!
Pour commencer, il faut comprendre que l'androïde ne sait actuellement que travailler avec les réseaux au format TensorFlowLite, ce qui signifie que nous devons effectuer certaines manipulations avec le réseau source. Supposons que vous disposiez d'un réseau déjà formé sur l'infrastructure Keras ou Tensorflow. Vous devez enregistrer la grille au format pb.
Commençons par le cas lorsque vous écrivez sur Tensorflow, alors tout est un peu plus facile.
saver = tf.train.Saver() tf.train.write_graph(session.graph_def, path_to_folder, "net.pb", False) tf.train.write_graph(session.graph_def, path_to_folder, "net.pbtxt", True) saver.save(session,path_to_folder+"model.ckpt")
Si vous écrivez dans Keras, vous devez créer un nouvel objet de session, enregistrer un lien vers celui-ci au début du fichier dans lequel vous entraînez le réseau et le transmettre à la fonction set_session
import keras.backend as K session = K.get_session() K.set_session(session)
Eh bien, vous avez enregistré le réseau, vous devez maintenant le convertir au format tflite. Pour ce faire, nous devons exécuter deux petits scripts, le premier «fige» le réseau, le second se traduira déjà au format souhaité. L'essence du «gel» est que tf ne stocke pas le poids des couches dans le fichier pb enregistré, mais les enregistre dans des points de contrôle spéciaux. Pour une conversion ultérieure en tflite, vous avez besoin de toutes les informations sur le réseau neuronal dans un seul fichier.
freeze_graph --input_binary=false --input_graph=net.pbtxt --output_node_names=result/Softmax --output_graph=frozen_graph.pb --input_checkpoint=model.ckpt
Notez que vous devez connaître le nom du tenseur de sortie. Dans tensorflow, vous pouvez le définir vous-même, en cas d'utilisation de Keras - définissez le nom dans le constructeur de couche
model.add(Dense(10,activation="softmax",name="result"))
Dans ce cas, le nom du tenseur ressemble généralement à «résultat / Softmax»
Si dans votre cas, ce n'est pas le cas, vous pouvez trouver le nom comme suit
[print(n.name) for n in session.graph.as_graph_def().node]
Il reste à exécuter le deuxième script
toco --graph_def_file=frozen-graph.pb --output_file=model.tflite --output_format=TFLITE --inference_type=FLOAT --input_arrays=input_input --output_arrays=result/Softmax --input_shapes=1,784
Hourra! Maintenant que vous avez un modèle TensorFlowLite dans votre dossier, c'est à vous de l'intégrer correctement dans votre application Android. Vous pouvez le faire avec le nouveau kit Firebase ML, mais il existe un autre moyen, à ce sujet un peu plus tard. Ajoutez une dépendance à notre fichier gradle
dependencies { // ... implementation 'com.google.firebase:firebase-ml-model-interpreter:16.2.0' }
Vous devez maintenant décider si vous conserverez le modèle quelque part sur votre serveur ou si vous le livrez avec l'application.
Prenons le premier cas: un modèle sur le serveur. Tout d'abord, n'oubliez pas d'ajouter au manifeste
<uses-permission android:name="android.permission.INTERNET" />
Si vous utilisez localement le modèle inclus dans l'application, n'oubliez pas d'ajouter l'entrée suivante au fichier build.gradle pour que le fichier modèle ne soit pas compressé
android { // ... aaptOptions { noCompress "tflite" } }
Après quoi, par analogie avec le modèle dans le nuage, notre neurone local doit être enregistré.
FirebaseLocalModelSource localSource = new FirebaseLocalModelSource.Builder("my_local_model") .setAssetFilePath("mymodel.tflite") .build(); FirebaseModelManager.getInstance().registerLocalModelSource(localSource);
Le code ci-dessus suppose que votre modèle se trouve dans le dossier des ressources, sinon, à la place
.setAssetFilePath("mymodel.tflite")
utiliser
.seFilePath(filePath)
Ensuite, nous créons de nouveaux objets FirebaseModelOptions et FirebaseModelInterpreter
FirebaseModelOptions options = new FirebaseModelOptions.Builder() .setCloudModelName("my_cloud_model") .setLocalModelName("my_local_model") .build(); FirebaseModelInterpreter firebaseInterpreter = FirebaseModelInterpreter.getInstance(options);
Vous pouvez utiliser à la fois des modèles locaux et basés sur un serveur. Dans ce cas, le cloud sera utilisé par défaut, s'il est disponible, sinon local.
Presque tout, il reste à créer des tableaux pour les données d'entrée / sortie, et à exécuter!
FirebaseModelInputOutputOptions inputOutputOptions = new FirebaseModelInputOutputOptions.Builder() .setInputFormat(0, FirebaseModelDataType.BYTE, new int[]{1, 640, 480, 3}) .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 784}) .build(); byte[][][][] input = new byte[1][640][480][3]; input = getYourInputData(); FirebaseModelInputs inputs = new FirebaseModelInputs.Builder() .add(input)
Si, pour une raison quelconque, vous ne souhaitez pas utiliser Firebase, il existe un autre moyen, en appelant l'interpréteur tflite et en lui fournissant directement des données.
Ajouter une ligne à construire / gradle
implementation 'org.tensorflow:tensorflow-lite:+'
Créer un interprète et des tableaux
Interpreter tflite = new Interpreter(loadModelFile(getContext(), "model.tflite"));
Le code dans ce cas est beaucoup moins, comme vous pouvez le voir.
C'est tout ce dont vous avez besoin pour utiliser votre réseau de neurones dans Android.
Liens utiles:
Off docks par ML KitTensorflow Lite