Android和Google ML Kit中的神经网络

因此,您已经开发并训练了神经网络来执行某种任务(例如,通过摄像头识别相同的对象),并希望在Android应用程序中实现它? 然后欢迎来到凯特!

首先,应该理解,android当前仅知道如何使用TensorFlowLite格式的网络,这意味着我们需要对源网络进行一些操作。 假设您已经在Keras或Tensorflow框架上训练有素的网络。 您必须以pb格式保存网格。

让我们从在Tensorflow上编写代码的情况开始,然后一切都会变得简单一些。

saver = tf.train.Saver() tf.train.write_graph(session.graph_def, path_to_folder, "net.pb", False) tf.train.write_graph(session.graph_def, path_to_folder, "net.pbtxt", True) saver.save(session,path_to_folder+"model.ckpt") 

如果使用Keras编写,则需要创建一个新的会话对象,将其链接保存在训练网络的文件的开头,然后将其传递给set_session函数。

 import keras.backend as K session = K.get_session() K.set_session(session) 

好了,您保存了网络,现在需要将其转换为tflite格式。 为此,我们需要运行两个小脚本,第一个“冻结”网络,第二个已经转换为所需的格式。 “冻结”的本质是tf不会将图层的权重存储在已保存的pb文件中,而是将其保存在特殊的检查点中。 为了随后转换为tflite,您需要在一个文件中提供有关神经网络的所有信息。

 freeze_graph --input_binary=false --input_graph=net.pbtxt --output_node_names=result/Softmax --output_graph=frozen_graph.pb --input_checkpoint=model.ckpt 

注意,您需要知道输出张量的名称。 在Tensorflow中,您可以自己设置它(如果使用Keras)-在图层构造函数中设置名称

 model.add(Dense(10,activation="softmax",name="result")) 

在这种情况下,张量的名称通常看起来像“结果/ Softmax”

如果不是这样,您可以找到以下名称

 [print(n.name) for n in session.graph.as_graph_def().node] 

它仍然可以运行第二个脚本

 toco --graph_def_file=frozen-graph.pb --output_file=model.tflite --output_format=TFLITE --inference_type=FLOAT --input_arrays=input_input --output_arrays=result/Softmax --input_shapes=1,784 

万岁! 现在,您的文件夹中有一个TensorFlowLite模型,由您将其正确集成到您的android应用程序中。 您可以使用新的Firebase ML Kit进行此操作,但是还有另一种方法,稍后再解决。 向我们的gradle文件添加依赖项

 dependencies { // ... implementation 'com.google.firebase:firebase-ml-model-interpreter:16.2.0' } 

现在,您需要确定是将模型保留在服务器上还是随应用程序一起提供。

考虑第一种情况:服务器上的模型。 首先,不要忘记添加清单

 <uses-permission android:name="android.permission.INTERNET" /> 

  //      ,   /  FirebaseModelDownloadConditions.Builder conditionsBuilder = new FirebaseModelDownloadConditions.Builder().requireWifi(); if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) { conditionsBuilder = conditionsBuilder .requireCharging(); } FirebaseModelDownloadConditions conditions = conditionsBuilder.build(); //   FirebaseCloudModelSource ,   (    ,  //   Firebase) FirebaseCloudModelSource cloudSource = new FirebaseCloudModelSource.Builder("my_cloud_model") .enableModelUpdates(true) .setInitialDownloadConditions(conditions) .setUpdatesDownloadConditions(conditions) .build(); FirebaseModelManager.getInstance().registerCloudModelSource(cloudSource); 

如果您在本地使用应用程序中包含的模型,请不要忘记将以下条目添加到build.gradle文件中,以便不压缩模型文件

 android { // ... aaptOptions { noCompress "tflite" } } 

然后,类似于云中的模型,需要注册我们的本地神经元。

 FirebaseLocalModelSource localSource = new FirebaseLocalModelSource.Builder("my_local_model") .setAssetFilePath("mymodel.tflite") .build(); FirebaseModelManager.getInstance().registerLocalModelSource(localSource); 

上面的代码假定您的模型位于assets文件夹中,如果没有,则改为

  .setAssetFilePath("mymodel.tflite") 

使用

  .seFilePath(filePath) 

然后,我们创建新对象FirebaseModelOptions和FirebaseModelInterpreter

 FirebaseModelOptions options = new FirebaseModelOptions.Builder() .setCloudModelName("my_cloud_model") .setLocalModelName("my_local_model") .build(); FirebaseModelInterpreter firebaseInterpreter = FirebaseModelInterpreter.getInstance(options); 

您可以同时使用本地模型和基于服务器的模型。 在这种情况下,默认情况下将使用云(如果可用),否则将使用本地云。

几乎所有内容都将保留为输入/输出数据创建数组,然后运行!

 FirebaseModelInputOutputOptions inputOutputOptions = new FirebaseModelInputOutputOptions.Builder() .setInputFormat(0, FirebaseModelDataType.BYTE, new int[]{1, 640, 480, 3}) .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 784}) .build(); byte[][][][] input = new byte[1][640][480][3]; input = getYourInputData(); FirebaseModelInputs inputs = new FirebaseModelInputs.Builder() .add(input) // add() as many input arrays as your model requires .build(); Task<FirebaseModelOutputs> result = firebaseInterpreter.run(inputs, inputOutputOptions) .addOnSuccessListener( new OnSuccessListener<FirebaseModelOutputs>() { @Override public void onSuccess(FirebaseModelOutputs result) { // ... } }) .addOnFailureListener( new OnFailureListener() { @Override public void onFailure(@NonNull Exception e) { // Task failed with an exception // ... } }); float[][] output = result.<float[][]>getOutput(0); float[] probabilities = output[0]; 

如果出于某些原因您不想使用Firebase,则可以使用另一种方法,即调用tflite解释器并直接将其输入数据。

添加一条线来构建/ gradle

  implementation 'org.tensorflow:tensorflow-lite:+' 

创建一个解释器和数组

  Interpreter tflite = new Interpreter(loadModelFile(getContext(), "model.tflite")); //     inputs tflite.run(inputs,outputs) 

如您所见,这种情况下的代码要少得多。

这就是在android中使用神经网络所需要的。

有用的链接:

ML Kit在码头上
Tensorflow Lite

Source: https://habr.com/ru/post/zh-CN422041/


All Articles