Entonces, ¿ha desarrollado y entrenado su red neuronal para realizar algún tipo de tarea (por ejemplo, el mismo reconocimiento de objetos a través de la cámara) y desea integrarlo en su aplicación de Android? Entonces bienvenido a Kat!
Para empezar, debe entenderse que el Android actualmente solo sabe cómo trabajar con redes de formato TensorFlowLite, lo que significa que necesitamos realizar algunas manipulaciones con la red de origen. Suponga que tiene una red ya capacitada en el marco de Keras o Tensorflow. Debe guardar la cuadrícula en formato pb.
Comencemos con el caso cuando escribe en Tensorflow, entonces todo es un poco más fácil.
saver = tf.train.Saver() tf.train.write_graph(session.graph_def, path_to_folder, "net.pb", False) tf.train.write_graph(session.graph_def, path_to_folder, "net.pbtxt", True) saver.save(session,path_to_folder+"model.ckpt")
Si escribe en Keras, debe crear un nuevo objeto de sesión, guardar un enlace al principio del archivo donde entrena la red y pasarlo a la función set_session
import keras.backend as K session = K.get_session() K.set_session(session)
Bueno, guardó la red, ahora necesita convertirla al formato tflite. Para hacer esto, necesitamos ejecutar dos pequeños scripts, el primero "congela" la red, el segundo ya se traducirá al formato deseado. La esencia de la "congelación" es que tf no almacena el peso de las capas en el archivo pb guardado, sino que las guarda en puntos de control especiales. Para la conversión posterior a tflite, necesita toda la información sobre la red neuronal en un solo archivo.
freeze_graph --input_binary=false --input_graph=net.pbtxt --output_node_names=result/Softmax --output_graph=frozen_graph.pb --input_checkpoint=model.ckpt
Tenga en cuenta que necesita saber el nombre del tensor de salida. En tensorflow puede configurarlo usted mismo, en caso de usar Keras: configure el nombre en el constructor de capas
model.add(Dense(10,activation="softmax",name="result"))
En este caso, el nombre del tensor generalmente se ve como "resultado / Softmax"
Si en su caso no es así, puede encontrar el nombre de la siguiente manera
[print(n.name) for n in session.graph.as_graph_def().node]
Queda por ejecutar el segundo script
toco --graph_def_file=frozen-graph.pb --output_file=model.tflite --output_format=TFLITE --inference_type=FLOAT --input_arrays=input_input --output_arrays=result/Softmax --input_shapes=1,784
¡Hurra! Ahora que tiene un modelo TensorFlowLite en su carpeta, depende de usted integrarlo correctamente en su aplicación de Android. Puedes hacer esto con el nuevo Kit Firebase ML, pero hay otra forma, un poco más adelante. Agregue una dependencia a nuestro archivo gradle
dependencies { // ... implementation 'com.google.firebase:firebase-ml-model-interpreter:16.2.0' }
Ahora debe decidir si mantendrá el modelo en algún lugar de su servidor o si lo enviará con la aplicación.
Considere el primer caso: un modelo en el servidor. Antes que nada, no olvides agregar al manifiesto
<uses-permission android:name="android.permission.INTERNET" />
Si utiliza el modelo incluido en la aplicación localmente, no olvide agregar la siguiente entrada al archivo build.gradle para que el archivo del modelo no esté comprimido
android { // ... aaptOptions { noCompress "tflite" } }
Después de lo cual, por analogía con el modelo en la nube, nuestra neurona local necesita ser registrada.
FirebaseLocalModelSource localSource = new FirebaseLocalModelSource.Builder("my_local_model") .setAssetFilePath("mymodel.tflite") .build(); FirebaseModelManager.getInstance().registerLocalModelSource(localSource);
El código anterior supone que su modelo está en la carpeta de activos, si no, en su lugar
.setAssetFilePath("mymodel.tflite")
usar
.seFilePath(filePath)
Luego creamos nuevos objetos FirebaseModelOptions y FirebaseModelInterpreter
FirebaseModelOptions options = new FirebaseModelOptions.Builder() .setCloudModelName("my_cloud_model") .setLocalModelName("my_local_model") .build(); FirebaseModelInterpreter firebaseInterpreter = FirebaseModelInterpreter.getInstance(options);
Puede usar modelos locales y basados en servidor al mismo tiempo. En este caso, la nube se utilizará de forma predeterminada, si está disponible, de lo contrario, local.
¡Casi todo, queda por crear matrices para datos de entrada / salida, y ejecutar!
FirebaseModelInputOutputOptions inputOutputOptions = new FirebaseModelInputOutputOptions.Builder() .setInputFormat(0, FirebaseModelDataType.BYTE, new int[]{1, 640, 480, 3}) .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 784}) .build(); byte[][][][] input = new byte[1][640][480][3]; input = getYourInputData(); FirebaseModelInputs inputs = new FirebaseModelInputs.Builder() .add(input)
Si por alguna razón no desea usar Firebase, hay otra forma, llamando al intérprete de tflite y alimentando sus datos directamente.
Agregar una línea para construir / gradle
implementation 'org.tensorflow:tensorflow-lite:+'
Crear un intérprete y matrices.
Interpreter tflite = new Interpreter(loadModelFile(getContext(), "model.tflite"));
El código en este caso es mucho menor, como puede ver.
Eso es todo lo que necesitas para usar tu red neuronal en Android.
Enlaces utiles:
Fuera de muelles por ML KitTensorflow Lite