因此,您已经开发并训练了神经网络来执行某种任务(例如,通过摄像头识别相同的对象),并希望在Android应用程序中实现它? 然后欢迎来到凯特!
首先,应该理解,android当前仅知道如何使用TensorFlowLite格式的网络,这意味着我们需要对源网络进行一些操作。 假设您已经在Keras或Tensorflow框架上训练有素的网络。 您必须以pb格式保存网格。
让我们从在Tensorflow上编写代码的情况开始,然后一切都会变得简单一些。
saver = tf.train.Saver() tf.train.write_graph(session.graph_def, path_to_folder, "net.pb", False) tf.train.write_graph(session.graph_def, path_to_folder, "net.pbtxt", True) saver.save(session,path_to_folder+"model.ckpt")
如果使用Keras编写,则需要创建一个新的会话对象,将其链接保存在训练网络的文件的开头,然后将其传递给set_session函数。
import keras.backend as K session = K.get_session() K.set_session(session)
好了,您保存了网络,现在需要将其转换为tflite格式。 为此,我们需要运行两个小脚本,第一个“冻结”网络,第二个已经转换为所需的格式。 “冻结”的本质是tf不会将图层的权重存储在已保存的pb文件中,而是将其保存在特殊的检查点中。 为了随后转换为tflite,您需要在一个文件中提供有关神经网络的所有信息。
freeze_graph --input_binary=false --input_graph=net.pbtxt --output_node_names=result/Softmax --output_graph=frozen_graph.pb --input_checkpoint=model.ckpt
注意,您需要知道输出张量的名称。 在Tensorflow中,您可以自己设置它(如果使用Keras)-在图层构造函数中设置名称
model.add(Dense(10,activation="softmax",name="result"))
在这种情况下,张量的名称通常看起来像“结果/ Softmax”
如果不是这样,您可以找到以下名称
[print(n.name) for n in session.graph.as_graph_def().node]
它仍然可以运行第二个脚本
toco --graph_def_file=frozen-graph.pb --output_file=model.tflite --output_format=TFLITE --inference_type=FLOAT --input_arrays=input_input --output_arrays=result/Softmax --input_shapes=1,784
万岁! 现在,您的文件夹中有一个TensorFlowLite模型,由您将其正确集成到您的android应用程序中。 您可以使用新的Firebase ML Kit进行此操作,但是还有另一种方法,稍后再解决。 向我们的gradle文件添加依赖项
dependencies { // ... implementation 'com.google.firebase:firebase-ml-model-interpreter:16.2.0' }
现在,您需要确定是将模型保留在服务器上还是随应用程序一起提供。
考虑第一种情况:服务器上的模型。 首先,不要忘记添加清单
<uses-permission android:name="android.permission.INTERNET" />
如果您在本地使用应用程序中包含的模型,请不要忘记将以下条目添加到build.gradle文件中,以便不压缩模型文件
android { // ... aaptOptions { noCompress "tflite" } }
然后,类似于云中的模型,需要注册我们的本地神经元。
FirebaseLocalModelSource localSource = new FirebaseLocalModelSource.Builder("my_local_model") .setAssetFilePath("mymodel.tflite") .build(); FirebaseModelManager.getInstance().registerLocalModelSource(localSource);
上面的代码假定您的模型位于assets文件夹中,如果没有,则改为
.setAssetFilePath("mymodel.tflite")
使用
.seFilePath(filePath)
然后,我们创建新对象FirebaseModelOptions和FirebaseModelInterpreter
FirebaseModelOptions options = new FirebaseModelOptions.Builder() .setCloudModelName("my_cloud_model") .setLocalModelName("my_local_model") .build(); FirebaseModelInterpreter firebaseInterpreter = FirebaseModelInterpreter.getInstance(options);
您可以同时使用本地模型和基于服务器的模型。 在这种情况下,默认情况下将使用云(如果可用),否则将使用本地云。
几乎所有内容都将保留为输入/输出数据创建数组,然后运行!
FirebaseModelInputOutputOptions inputOutputOptions = new FirebaseModelInputOutputOptions.Builder() .setInputFormat(0, FirebaseModelDataType.BYTE, new int[]{1, 640, 480, 3}) .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 784}) .build(); byte[][][][] input = new byte[1][640][480][3]; input = getYourInputData(); FirebaseModelInputs inputs = new FirebaseModelInputs.Builder() .add(input)
如果出于某些原因您不想使用Firebase,则可以使用另一种方法,即调用tflite解释器并直接将其输入数据。
添加一条线来构建/ gradle
implementation 'org.tensorflow:tensorflow-lite:+'
创建一个解释器和数组
Interpreter tflite = new Interpreter(loadModelFile(getContext(), "model.tflite"));
如您所见,这种情况下的代码要少得多。
这就是在android中使用神经网络所需要的。
有用的链接:
ML Kit在码头上Tensorflow Lite