Transfer Learning: cara cepat melatih jaringan saraf pada data Anda

Pembelajaran mesin menjadi lebih mudah diakses, ada lebih banyak peluang untuk menerapkan teknologi ini menggunakan “komponen yang tidak tersedia”. Sebagai contoh, Transfer Learning memungkinkan Anda untuk menggunakan pengalaman yang diperoleh dalam menyelesaikan satu masalah untuk memecahkan masalah lain yang serupa. Jaringan saraf pertama kali dilatih tentang sejumlah besar data, kemudian pada target yang ditetapkan.

Pengakuan makanan

Pada artikel ini saya akan memberi tahu Anda cara menggunakan metode Transfer Learning menggunakan contoh pengenalan gambar dengan makanan. Saya akan berbicara tentang alat pembelajaran mesin lainnya di lokakarya Machine Learning dan Neural Networks for Developers .

Jika kita dihadapkan pada tugas pengenalan gambar, Anda dapat menggunakan layanan yang sudah jadi. Namun, jika Anda perlu melatih model pada set data Anda sendiri, Anda harus melakukannya sendiri.

Untuk tugas-tugas umum seperti klasifikasi gambar, Anda dapat menggunakan arsitektur yang sudah jadi (AlexNet, VGG, Inception, ResNet, dll.) Dan melatih jaringan saraf pada data Anda. Implementasi jaringan seperti itu menggunakan berbagai kerangka kerja yang sudah ada, jadi pada tahap ini Anda dapat menggunakan salah satunya sebagai kotak hitam, tanpa mempelajari prinsip operasinya secara mendalam.

Namun, jaringan saraf yang dalam menuntut data dalam jumlah besar untuk konvergensi pembelajaran. Dan seringkali dalam tugas khusus kita tidak ada cukup data untuk melatih semua lapisan jaringan saraf dengan benar. Transfer Learning memecahkan masalah ini.

Transfer Pembelajaran untuk Klasifikasi Gambar


Jaringan saraf yang digunakan untuk klasifikasi biasanya mengandung neuron keluaran N di lapisan terakhir, di mana N adalah jumlah kelas. Seperti vektor keluaran diperlakukan sebagai satu set probabilitas milik kelas. Dalam tugas kami mengenali gambar makanan, jumlah kelas mungkin berbeda dari yang ada di dataset asli. Dalam hal ini, kita harus benar-benar membuang lapisan terakhir ini dan memasukkan yang baru, dengan jumlah neuron output yang tepat

Transfer belajar

Seringkali di akhir jaringan klasifikasi, lapisan yang terhubung sepenuhnya digunakan. Karena kami mengganti lapisan ini, menggunakan bobot yang sudah dilatih sebelumnya untuk itu tidak akan berfungsi. Anda harus melatihnya dari awal, menginisialisasi bobotnya dengan nilai acak. Kami memuat bobot untuk semua lapisan lainnya dari snapshot yang sudah dilatih sebelumnya.

Ada berbagai strategi untuk pelatihan model selanjutnya. Kami akan menggunakan yang berikut ini: kami akan melatih seluruh jaringan dari ujung ke ujung ( ujung ke ujung ), dan kami tidak akan memperbaiki bobot pra-pelatihan untuk memungkinkan mereka menyesuaikan sedikit dan menyesuaikan dengan data kami. Proses ini disebut fine-tuning .

Komponen struktural


Untuk mengatasi masalah tersebut, kita membutuhkan komponen berikut:

  1. Deskripsi model jaringan saraf
  2. Pipa belajar
  3. Pipa interferensi
  4. Bobot pra-terlatih untuk model ini
  5. Data untuk pelatihan dan validasi

Komponen

Dalam contoh kita, saya akan mengambil komponen (1), (2) dan (3) dari repositori saya sendiri , yang berisi kode paling ringan - Anda dapat dengan mudah mengetahuinya jika Anda mau. Contoh kami akan diimplementasikan pada kerangka kerja TensorFlow yang populer. Bobot pra-terlatih (4) yang cocok untuk kerangka kerja yang dipilih dapat ditemukan jika sesuai dengan salah satu arsitektur klasik. Sebagai dataset (5) untuk demonstrasi saya akan mengambil Food-101 .

Model


Sebagai model, kami menggunakan jaringan saraf VGG klasik (lebih tepatnya, VGG19 ). Meskipun ada beberapa kelemahan, model ini menunjukkan kualitas yang cukup tinggi. Selain itu, mudah untuk dianalisis. Pada TensorFlow Slim, deskripsi model terlihat cukup ringkas:

 import tensorflow as tf import tensorflow.contrib.slim as slim def vgg_19(inputs, num_classes, is_training, scope='vgg_19', weight_decay=0.0005): with slim.arg_scope([slim.conv2d], activation_fn=tf.nn.relu, weights_regularizer=slim.l2_regularizer(weight_decay), biases_initializer=tf.zeros_initializer(), padding='SAME'): with tf.variable_scope(scope, 'vgg_19', [inputs]): net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1') net = slim.max_pool2d(net, [2, 2], scope='pool1') net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2') net = slim.max_pool2d(net, [2, 2], scope='pool2') net = slim.repeat(net, 4, slim.conv2d, 256, [3, 3], scope='conv3') net = slim.max_pool2d(net, [2, 2], scope='pool3') net = slim.repeat(net, 4, slim.conv2d, 512, [3, 3], scope='conv4') net = slim.max_pool2d(net, [2, 2], scope='pool4') net = slim.repeat(net, 4, slim.conv2d, 512, [3, 3], scope='conv5') net = slim.max_pool2d(net, [2, 2], scope='pool5') # Use conv2d instead of fully_connected layers net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6') net = slim.dropout(net, 0.5, is_training=is_training, scope='drop6') net = slim.conv2d(net, 4096, [1, 1], scope='fc7') net = slim.dropout(net, 0.5, is_training=is_training, scope='drop7') net = slim.conv2d(net, num_classes, [1, 1], scope='fc8', activation_fn=None) net = tf.squeeze(net, [1, 2], name='fc8/squeezed') return net 

Bobot untuk VGG19, dilatih di ImageNet dan kompatibel dengan TensorFlow, diunduh dari repositori di GitHub dari bagian Model Pra-terlatih .

 mkdir data && cd data wget http://download.tensorflow.org/models/vgg_19_2016_08_28.tar.gz tar -xzf vgg_19_2016_08_28.tar.gz 

Datacet


Sebagai sampel pelatihan dan validasi, kami akan menggunakan dataset Food-101 publik, yang berisi lebih dari 100 ribu gambar makanan, dibagi menjadi 101 kategori.

Dataset Makanan-101

Unduh dan buka paket dataset:

 cd data wget http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz tar -xzf food-101.tar.gz 

Pipa data dalam pelatihan kami dirancang sehingga dari dataset kami perlu menguraikan berikut ini:

  1. Daftar kelas (kategori)
  2. Tutorial: daftar jalur ke gambar dan daftar jawaban yang benar
  3. Kumpulan validasi: daftar jalur ke gambar dan daftar jawaban yang benar

Jika dataset Anda, maka untuk kereta dan validasi Anda perlu memecahkan set sendiri. Food-101 sudah memiliki partisi seperti itu, dan informasi ini disimpan di direktori meta .

 DATASET_ROOT = 'data/food-101/' train_data, val_data, classes = data.food101(DATASET_ROOT) num_classes = len(classes) 

Semua fungsi tambahan yang bertanggung jawab untuk pemrosesan data dipindahkan ke file data.py terpisah:

data.py
 from os.path import join as opj import tensorflow as tf def parse_ds_subset(img_root, list_fpath, classes): ''' Parse a meta file with image paths and labels -> img_root: path to the root of image folders -> list_fpath: path to the file with the list (eg train.txt) -> classes: list of class names <- (list_of_img_paths, integer_labels) ''' fpaths = [] labels = [] with open(list_fpath, 'r') as f: for line in f: class_name, image_id = line.strip().split('/') fpaths.append(opj(img_root, class_name, image_id+'.jpg')) labels.append(classes.index(class_name)) return fpaths, labels def food101(dataset_root): ''' Get lists of train and validation examples for Food-101 dataset -> dataset_root: root of the Food-101 dataset <- ((train_fpaths, train_labels), (val_fpaths, val_labels), classes) ''' img_root = opj(dataset_root, 'images') train_list_fpath = opj(dataset_root, 'meta', 'train.txt') test_list_fpath = opj(dataset_root, 'meta', 'test.txt') classes_list_fpath = opj(dataset_root, 'meta', 'classes.txt') with open(classes_list_fpath, 'r') as f: classes = [line.strip() for line in f] train_data = parse_ds_subset(img_root, train_list_fpath, classes) val_data = parse_ds_subset(img_root, test_list_fpath, classes) return train_data, val_data, classes def imread_and_crop(fpath, inp_size, margin=0, random_crop=False): ''' Construct TF graph for image preparation: Read the file, crop and resize -> fpath: path to the JPEG image file (TF node) -> inp_size: size of the network input (eg 224) -> margin: cropping margin -> random_crop: perform random crop or central crop <- prepared image (TF node) ''' data = tf.read_file(fpath) img = tf.image.decode_jpeg(data, channels=3) img = tf.image.convert_image_dtype(img, dtype=tf.float32) shape = tf.shape(img) crop_size = tf.minimum(shape[0], shape[1]) - 2 * margin if random_crop: img = tf.random_crop(img, (crop_size, crop_size, 3)) else: # central crop ho = (shape[0] - crop_size) // 2 wo = (shape[0] - crop_size) // 2 img = img[ho:ho+crop_size, wo:wo+crop_size, :] img = tf.image.resize_images(img, (inp_size, inp_size), method=tf.image.ResizeMethod.AREA) return img def train_dataset(data, batch_size, epochs, inp_size, margin): ''' Prepare training data pipeline -> data: (list_of_img_paths, integer_labels) -> batch_size: training batch size -> epochs: number of training epochs -> inp_size: size of the network input (eg 224) -> margin: cropping margin <- (dataset, number_of_train_iterations) ''' num_examples = len(data[0]) iters = (epochs * num_examples) // batch_size def fpath_to_image(fpath, label): img = imread_and_crop(fpath, inp_size, margin, random_crop=True) return img, label dataset = tf.data.Dataset.from_tensor_slices(data) dataset = dataset.shuffle(buffer_size=num_examples) dataset = dataset.map(fpath_to_image) dataset = dataset.repeat(epochs) dataset = dataset.batch(batch_size, drop_remainder=True) return dataset, iters def val_dataset(data, batch_size, inp_size): ''' Prepare validation data pipeline -> data: (list_of_img_paths, integer_labels) -> batch_size: validation batch size -> inp_size: size of the network input (eg 224) <- (dataset, number_of_val_iterations) ''' num_examples = len(data[0]) iters = num_examples // batch_size def fpath_to_image(fpath, label): img = imread_and_crop(fpath, inp_size, 0, random_crop=False) return img, label dataset = tf.data.Dataset.from_tensor_slices(data) dataset = dataset.map(fpath_to_image) dataset = dataset.batch(batch_size, drop_remainder=True) return dataset, iters 


Pelatihan model


Kode pelatihan model terdiri dari langkah-langkah berikut:

  1. Membangun jalur data kereta / validasi
  2. Membangun kereta / grafik validasi (jaringan)
  3. Lampiran fungsi klasifikasi kerugian ( cross entropy loss ) di atas grafik kereta
  4. Kode diperlukan untuk menghitung keakuratan prediksi pada sampel validasi selama pelatihan
  5. Logika untuk memuat skala pra-terlatih dari snapshot
  6. Membuat berbagai struktur untuk pembelajaran
  7. Siklus pembelajaran itu sendiri (optimasi berulang)

Lapisan terakhir grafik dikonstruksi dengan jumlah neuron yang diperlukan dan dikeluarkan dari daftar parameter yang dimuat dari snapshot pra-terlatih.

Kode Pelatihan Model
 import numpy as np import tensorflow as tf import tensorflow.contrib.slim as slim tf.logging.set_verbosity(tf.logging.INFO) import model import data ########################################################### ### Settings ########################################################### INPUT_SIZE = 224 RANDOM_CROP_MARGIN = 10 TRAIN_EPOCHS = 20 TRAIN_BATCH_SIZE = 64 VAL_BATCH_SIZE = 128 LR_START = 0.001 LR_END = LR_START / 1e4 MOMENTUM = 0.9 VGG_PRETRAINED_CKPT = 'data/vgg_19.ckpt' CHECKPOINT_DIR = 'checkpoints/vgg19_food' LOG_LOSS_EVERY = 10 CALC_ACC_EVERY = 500 ########################################################### ### Build training and validation data pipelines ########################################################### train_ds, train_iters = data.train_dataset(train_data, TRAIN_BATCH_SIZE, TRAIN_EPOCHS, INPUT_SIZE, RANDOM_CROP_MARGIN) train_ds_iterator = train_ds.make_one_shot_iterator() train_x, train_y = train_ds_iterator.get_next() val_ds, val_iters = data.val_dataset(val_data, VAL_BATCH_SIZE, INPUT_SIZE) val_ds_iterator = val_ds.make_initializable_iterator() val_x, val_y = val_ds_iterator.get_next() ########################################################### ### Construct training and validation graphs ########################################################### with tf.variable_scope('', reuse=tf.AUTO_REUSE): train_logits = model.vgg_19(train_x, num_classes, is_training=True) val_logits = model.vgg_19(val_x, num_classes, is_training=False) ########################################################### ### Construct training loss ########################################################### loss = tf.losses.sparse_softmax_cross_entropy( labels=train_y, logits=train_logits) tf.summary.scalar('loss', loss) ########################################################### ### Construct validation accuracy ### and related functions ########################################################### def calc_accuracy(sess, val_logits, val_y, val_iters): acc_total = 0.0 acc_denom = 0 for i in range(val_iters): logits, y = sess.run((val_logits, val_y)) y_pred = np.argmax(logits, axis=1) correct = np.count_nonzero(y == y_pred) acc_denom += y_pred.shape[0] acc_total += float(correct) tf.logging.info('Validating batch [{} / {}] correct = {}'.format( i, val_iters, correct)) acc_total /= acc_denom return acc_total def accuracy_summary(sess, acc_value, iteration): acc_summary = tf.Summary() acc_summary.value.add(tag="accuracy", simple_value=acc_value) sess._hooks[1]._summary_writer.add_summary(acc_summary, iteration) ########################################################### ### Define set of VGG variables to restore ### Create the Restorer ### Define init callback (used by monitored session) ########################################################### vars_to_restore = tf.contrib.framework.get_variables_to_restore( exclude=['vgg_19/fc8']) vgg_restorer = tf.train.Saver(vars_to_restore) def init_fn(scaffold, sess): vgg_restorer.restore(sess, VGG_PRETRAINED_CKPT) ########################################################### ### Create various training structures ########################################################### global_step = tf.train.get_or_create_global_step() lr = tf.train.polynomial_decay(LR_START, global_step, train_iters, LR_END) tf.summary.scalar('learning_rate', lr) optimizer = tf.train.MomentumOptimizer(learning_rate=lr, momentum=MOMENTUM) training_op = slim.learning.create_train_op( loss, optimizer, global_step=global_step) scaffold = tf.train.Scaffold(init_fn=init_fn) ########################################################### ### Create monitored session ### Run training loop ########################################################### with tf.train.MonitoredTrainingSession(checkpoint_dir=CHECKPOINT_DIR, save_checkpoint_secs=600, save_summaries_steps=30, scaffold=scaffold) as sess: start_iter = sess.run(global_step) for iteration in range(start_iter, train_iters): # Gradient Descent loss_value = sess.run(training_op) # Loss logging if iteration % LOG_LOSS_EVERY == 0: tf.logging.info('[{} / {}] Loss = {}'.format( iteration, train_iters, loss_value)) # Accuracy logging if iteration % CALC_ACC_EVERY == 0: sess.run(val_ds_iterator.initializer) acc_value = calc_accuracy(sess, val_logits, val_y, val_iters) accuracy_summary(sess, acc_value, iteration) tf.logging.info('[{} / {}] Validation accuracy = {}'.format( iteration, train_iters, acc_value)) 


Setelah memulai pelatihan, Anda dapat melihat progresnya menggunakan utilitas TensorBoard, yang dilengkapi dengan TensorFlow dan berfungsi untuk memvisualisasikan berbagai metrik dan parameter lainnya.

 tensorboard --logdir checkpoints/ 

Pada akhir pelatihan di TensorBoard, kami melihat gambaran yang hampir sempurna: penurunan kehilangan kereta dan peningkatan Akurasi Validasi

Kehilangan dan akurasi TensorBoard

Hasilnya, kami mendapatkan snapshot yang disimpan di checkpoints/vgg19_food , yang akan kami gunakan selama pengujian model kami ( inferensi ).

Pengujian model


Sekarang uji model kami. Untuk melakukan ini:

  1. Kami membuat grafik baru yang dirancang khusus untuk inferensi ( is_training=False )
  2. Memuat bobot terlatih dari foto
  3. Unduh dan preprocess gambar uji input.
  4. Mari kita mengarahkan gambar melalui jaringan saraf dan mendapatkan prediksi

inference.py
 import sys import numpy as np import imageio from skimage.transform import resize import tensorflow as tf import model ########################################################### ### Settings ########################################################### CLASSES_FPATH = 'data/food-101/meta/labels.txt' INP_SIZE = 224 # Input will be cropped and resized CHECKPOINT_DIR = 'checkpoints/vgg19_food' IMG_FPATH = 'data/food-101/images/bruschetta/3564471.jpg' ########################################################### ### Get all class names ########################################################### with open(CLASSES_FPATH, 'r') as f: classes = [line.strip() for line in f] num_classes = len(classes) ########################################################### ### Construct inference graph ########################################################### x = tf.placeholder(tf.float32, (1, INP_SIZE, INP_SIZE, 3), name='inputs') logits = model.vgg_19(x, num_classes, is_training=False) ########################################################### ### Create TF session and restore from a snapshot ########################################################### sess = tf.Session() snapshot_fpath = tf.train.latest_checkpoint(CHECKPOINT_DIR) restorer = tf.train.Saver() restorer.restore(sess, snapshot_fpath) ########################################################### ### Load and prepare input image ########################################################### def crop_and_resize(img, input_size): crop_size = min(img.shape[0], img.shape[1]) ho = (img.shape[0] - crop_size) // 2 wo = (img.shape[0] - crop_size) // 2 img = img[ho:ho+crop_size, wo:wo+crop_size, :] img = resize(img, (input_size, input_size), order=3, mode='reflect', anti_aliasing=True, preserve_range=True) return img img = imageio.imread(IMG_FPATH) img = img.astype(np.float32) img = crop_and_resize(img, INP_SIZE) img = img[None, ...] ########################################################### ### Run inference ########################################################### out = sess.run(logits, feed_dict={x:img}) pred_class = classes[np.argmax(out)] print('Input: {}'.format(IMG_FPATH)) print('Prediction: {}'.format(pred_class)) 


Kesimpulan

Semua kode, termasuk sumber daya untuk membangun dan menjalankan wadah Docker dengan semua versi perpustakaan yang diperlukan, ada di repositori ini - pada saat membaca artikel, kode dalam repositori mungkin memiliki pembaruan.

Pada lokakarya "Pembelajaran Mesin dan Jaringan Saraf Tiruan untuk Pengembang" Saya akan menganalisis tugas pembelajaran mesin lainnya, dan siswa akan mempresentasikan proyek mereka pada akhir sesi intensif.

Source: https://habr.com/ru/post/id428255/


All Articles