Jaringan saraf di Android, Google ML Kit dan tidak hanya

Jadi, Anda telah mengembangkan dan melatih jaringan saraf Anda untuk melakukan beberapa jenis tugas (misalnya, pengenalan objek yang sama melalui kamera) dan ingin mengimplementasikannya dalam aplikasi android Anda? Kalau begitu selamat datang di kat!

Untuk memulainya, harus dipahami bahwa android saat ini hanya tahu cara bekerja dengan jaringan format TensorFlowLite, yang berarti kita perlu melakukan beberapa manipulasi dengan jaringan sumber. Misalkan Anda memiliki jaringan yang sudah terlatih pada kerangka Keras atau Tensorflow. Anda harus menyimpan kisi dalam format pb.

Mari kita mulai dengan kasing ketika Anda menulis di Tensorflow, maka semuanya menjadi sedikit lebih mudah.

saver = tf.train.Saver() tf.train.write_graph(session.graph_def, path_to_folder, "net.pb", False) tf.train.write_graph(session.graph_def, path_to_folder, "net.pbtxt", True) saver.save(session,path_to_folder+"model.ckpt") 

Jika Anda menulis dalam Keras, Anda perlu membuat objek sesi baru, menyimpan tautan ke sana di awal file tempat Anda melatih jaringan, dan meneruskannya ke fungsi set_session

 import keras.backend as K session = K.get_session() K.set_session(session) 

Nah, Anda menyimpan jaringan, sekarang Anda perlu mengubahnya menjadi format tflite. Untuk melakukan ini, kita perlu menjalankan dua skrip kecil, yang pertama "membekukan" jaringan, yang kedua akan menerjemahkan ke dalam format yang diinginkan. Inti dari "pembekuan" adalah bahwa tf tidak menyimpan bobot lapisan dalam file pb yang disimpan, tetapi menyimpannya di pos pemeriksaan khusus. Untuk konversi selanjutnya ke tflite, Anda memerlukan semua informasi tentang jaringan saraf dalam satu file.

 freeze_graph --input_binary=false --input_graph=net.pbtxt --output_node_names=result/Softmax --output_graph=frozen_graph.pb --input_checkpoint=model.ckpt 

Perhatikan bahwa Anda perlu tahu nama tensor output. Dalam tensorflow Anda dapat mengaturnya sendiri, jika menggunakan Keras - setel nama dalam konstruktor layer

 model.add(Dense(10,activation="softmax",name="result")) 

Dalam hal ini, nama tensor biasanya terlihat seperti "result / Softmax"

Jika dalam kasus Anda tidak demikian, Anda dapat menemukan namanya sebagai berikut

 [print(n.name) for n in session.graph.as_graph_def().node] 

Masih menjalankan skrip kedua

 toco --graph_def_file=frozen-graph.pb --output_file=model.tflite --output_format=TFLITE --inference_type=FLOAT --input_arrays=input_input --output_arrays=result/Softmax --input_shapes=1,784 

Hore! Sekarang Anda memiliki model TensorFlowLite di folder Anda, terserah Anda untuk mengintegrasikannya dengan benar ke dalam aplikasi Android Anda. Anda dapat melakukan ini dengan Firebase ML Kit yang bermodel baru, tetapi ada cara lain, tentang hal itu sedikit kemudian. Tambahkan dependensi ke file gradle kami

 dependencies { // ... implementation 'com.google.firebase:firebase-ml-model-interpreter:16.2.0' } 

Sekarang Anda perlu memutuskan apakah Anda akan menyimpan model di suatu tempat di server Anda, atau mengirim dengan aplikasi.

Pertimbangkan kasus pertama: model di server. Pertama-tama, jangan lupa untuk menambahkan manifes

 <uses-permission android:name="android.permission.INTERNET" /> 

  //      ,   /  FirebaseModelDownloadConditions.Builder conditionsBuilder = new FirebaseModelDownloadConditions.Builder().requireWifi(); if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) { conditionsBuilder = conditionsBuilder .requireCharging(); } FirebaseModelDownloadConditions conditions = conditionsBuilder.build(); //   FirebaseCloudModelSource ,   (    ,  //   Firebase) FirebaseCloudModelSource cloudSource = new FirebaseCloudModelSource.Builder("my_cloud_model") .enableModelUpdates(true) .setInitialDownloadConditions(conditions) .setUpdatesDownloadConditions(conditions) .build(); FirebaseModelManager.getInstance().registerCloudModelSource(cloudSource); 

Jika Anda menggunakan model yang termasuk dalam aplikasi secara lokal, jangan lupa untuk menambahkan entri berikut ke file build.gradle sehingga file model tidak dikompresi

 android { // ... aaptOptions { noCompress "tflite" } } 

Setelah itu, dengan analogi dengan model di awan, neuron lokal kita perlu didaftarkan.

 FirebaseLocalModelSource localSource = new FirebaseLocalModelSource.Builder("my_local_model") .setAssetFilePath("mymodel.tflite") .build(); FirebaseModelManager.getInstance().registerLocalModelSource(localSource); 

Kode di atas mengasumsikan model Anda ada di folder aset, jika tidak, sebagai gantinya

  .setAssetFilePath("mymodel.tflite") 

gunakan

  .seFilePath(filePath) 

Kemudian kita membuat objek baru FirebaseModelOptions dan FirebaseModelInterpreter

 FirebaseModelOptions options = new FirebaseModelOptions.Builder() .setCloudModelName("my_cloud_model") .setLocalModelName("my_local_model") .build(); FirebaseModelInterpreter firebaseInterpreter = FirebaseModelInterpreter.getInstance(options); 

Anda dapat menggunakan model lokal dan berbasis server secara bersamaan. Dalam hal ini, cloud akan digunakan secara default, jika tersedia, jika tidak lokal.

Hampir semuanya, tetap membuat array untuk input / output data, dan jalankan!

 FirebaseModelInputOutputOptions inputOutputOptions = new FirebaseModelInputOutputOptions.Builder() .setInputFormat(0, FirebaseModelDataType.BYTE, new int[]{1, 640, 480, 3}) .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 784}) .build(); byte[][][][] input = new byte[1][640][480][3]; input = getYourInputData(); FirebaseModelInputs inputs = new FirebaseModelInputs.Builder() .add(input) // add() as many input arrays as your model requires .build(); Task<FirebaseModelOutputs> result = firebaseInterpreter.run(inputs, inputOutputOptions) .addOnSuccessListener( new OnSuccessListener<FirebaseModelOutputs>() { @Override public void onSuccess(FirebaseModelOutputs result) { // ... } }) .addOnFailureListener( new OnFailureListener() { @Override public void onFailure(@NonNull Exception e) { // Task failed with an exception // ... } }); float[][] output = result.<float[][]>getOutput(0); float[] probabilities = output[0]; 

Jika karena alasan tertentu Anda tidak ingin menggunakan Firebase, ada cara lain, memanggil juru bahasa tflite dan mengumpankan data secara langsung.

Tambahkan baris untuk membangun / gradle

  implementation 'org.tensorflow:tensorflow-lite:+' 

Buat juru bahasa dan array

  Interpreter tflite = new Interpreter(loadModelFile(getContext(), "model.tflite")); //     inputs tflite.run(inputs,outputs) 

Kode dalam kasus ini jauh lebih sedikit, seperti yang Anda lihat.

Itu semua yang Anda butuhkan untuk menggunakan jaringan saraf Anda di android.

Tautan yang bermanfaat:

Off dok oleh ML Kit
Tensorflow Lite

Source: https://habr.com/ru/post/id422041/


All Articles