Transfer gaya

Transfer gaya adalah proses mengubah gaya sumber ke gaya gambar yang dipilih dan bergantung pada Jenis Jaringan Konvolusi (CNN), yang sudah dilatih sebelumnya, sehingga banyak yang akan tergantung pada pilihan jaringan yang dilatih ini. Untungnya, ada jaringan seperti itu dan ada banyak untuk dipilih, tetapi VGG-16 akan digunakan di sini.

Pertama, Anda perlu menghubungkan perpustakaan yang diperlukan

Kode deklarasi perpustakaan
import time import torch from torch.autograd import Variable import torch.nn as nn import torch.nn.functional as F from torch import optim import torchvision from torchvision import transforms from io import BytesIO from PIL import Image from collections import OrderedDict from google.colab import files 

Maka Anda perlu mendeklarasikan kelas jaringan pra-terlatih VGG-16

Kode Kelas VGG-16
 class VGG16(nn.Module): def __init__(self, pool='max'): super(VGG, self).__init__() self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1) self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1) self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1) self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1) self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1) self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1) self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1) self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1) self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1) self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1) self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1) self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1) if pool == 'max': self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2) elif pool == 'avg': self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2) self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2) self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2) self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2) self.pool5 = nn.AvgPool2d(kernel_size=2, stride=2) def forward(self, x, layers): out = {} out['relu1_1'] = F.relu(self.conv1_1(x)) out['relu1_2'] = F.relu(self.conv1_2(out['relu1_1'])) out['pool1'] = self.pool1(out['relu1_2']) out['relu2_1'] = F.relu(self.conv2_1(out['pool1'])) out['relu2_2'] = F.relu(self.conv2_2(out['relu2_1'])) out['pool2'] = self.pool2(out['relu2_2']) out['relu3_1'] = F.relu(self.conv3_1(out['pool2'])) out['relu3_2'] = F.relu(self.conv3_2(out['relu3_1'])) out['relu3_3'] = F.relu(self.conv3_3(out['relu3_2'])) out['relu3_4'] = F.relu(self.conv3_4(out['relu3_3'])) out['pool3'] = self.pool3(out['relu3_4']) out['relu4_1'] = F.relu(self.conv4_1(out['pool3'])) out['relu4_2'] = F.relu(self.conv4_2(out['relu4_1'])) out['relu4_3'] = F.relu(self.conv4_3(out['relu4_2'])) out['relu4_4'] = F.relu(self.conv4_4(out['relu4_3'])) out['pool4'] = self.pool4(out['relu4_4']) out['relu5_1'] = F.relu(self.conv5_1(out['pool4'])) out['relu5_2'] = F.relu(self.conv5_2(out['relu5_1'])) out['relu5_3'] = F.relu(self.conv5_3(out['relu5_2'])) out['relu5_4'] = F.relu(self.conv5_4(out['relu5_3'])) out['pool5'] = self.pool5(out['relu5_4']) return [out[key] for key in layers] 


Selanjutnya, Anda perlu mengunduh dan memuat bobot VGG-16, setelah sebelumnya mentransfernya ke kartu video, jika memungkinkan

 vgg = VGG16() vgg.load_state_dict(torch.load('vgg_conv.pth')) for param in vgg.parameters(): param.requires_grad = False if torch.cuda.is_available(): vgg.cuda() 

Di mana vgg_conv.pth adalah nama dari file bobot jaringan.

Dalam hal ini, perlu untuk menonaktifkan pelatihan parameter pada jaringan, jika tidak, Anda dapat merusak bobot yang dimuat yang telah dilatih selama lebih dari satu hari.

Setelah itu, fungsi konversi gambar input diumumkan untuk membawa mereka ke bentuk gambar di mana jaringan VGG-16 dilatih

Masukkan Kode Fungsi Konversi Gambar
 SIZE_IMAGE = 512 to_mean_tensor = transforms.Compose([transforms.Resize(SIZE_IMAGE), transforms.ToTensor(), transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])]), transforms.Normalize(mean=[0.40760392, 0.45795686, 0.48501961], std=[1,1,1]), transforms.Lambda(lambda x: x.mul_(255)), ]) to_unmean_tensor = transforms.Compose([transforms.Lambda(lambda x: x.div_(255)), transforms.Normalize(mean=[-0.40760392, -0.45795686, -0.48501961], std=[1,1,1]), transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])]), ]) to_image = transforms.Compose([transforms.ToPILImage()]) normalize_image = lambda t: to_image(torch.clamp(to_unmean_tensor(t), min=0, max=1)) 

to_mean_tensor - konversi langsung
normalize_image - transformasi terbalik

Selanjutnya, kelas matriks Gram dan fungsi kerugian untuk matriks Gram diumumkan

 class GramMatrix(nn.Module): def forward(self, input): b,c,h,w = input.size() F = input.view(b, c, h*w) G = torch.bmm(F, F.transpose(1,2)) G.div_(h*w) return G class GramMSELoss(nn.Module): def forward(self, input, target): out = nn.MSELoss()(GramMatrix()(input), target) return out 

Matriks Gram berfungsi untuk menghilangkan referensi spasial dari detail gaya.

Kemudian muncul proses memuat dan mengubah sumber dan gaya gambar.

 imgs = [style_img, content_img] imgs_torch = [to_mean_tensor(img) for img in imgs] if torch.cuda.is_available(): imgs_torch = [Variable(img.unsqueeze(0).cuda()) for img in imgs_torch] else: imgs_torch = [Variable(img.unsqueeze(0)) for img in imgs_torch] style_image, content_image = imgs_torch opt_img = Variable(content_image.data.clone(), requires_grad=True) 

Di mana style_img dan content_img adalah gambar input yang dikonversi ke tensor dan ditransfer ke kartu video jika memungkinkan, dan opt_img akan berisi hasil dari transfer gaya, sedangkan gambar awal diambil sebagai gambar awal.

Berikutnya adalah proses memilih lapisan, mengatur bobot dan menginisialisasi fungsi kerugian

Kode Berat dan Kehilangan
 style_layers = ['relu1_1','relu2_1','relu3_1','relu4_1', 'relu5_1'] content_layers = ['relu4_2'] loss_layers = style_layers + content_layers losses = [GramMSELoss()] * len(style_layers) + [nn.MSELoss()] * len(content_layers) if torch.cuda.is_available(): losses = [loss.cuda() for loss in losses] style_weights = [1e3/n**2 for n in [64,128,256,512,512]] content_weights = [1e0] weights = style_weights + content_weights style_targets = [GramMatrix()(A).detach() for A in vgg(style_image, style_layers)] content_targets = [A.detach() for A in vgg(content_image, content_layers)] targets = style_targets + content_targets 


Dan langkah terakhir adalah proses mentransfer gaya
 epochs = 300 opt = optim.LBFGS([opt_img]) def step_opt(): opt.zero_grad() out_layers = vgg(opt_img, loss_layers) layer_losses = [] for j, out in enumerate(out_layers): layer_losses.append(weights[j] * losses[j](out, targets[j])) loss = sum(layer_losses) loss.backward() return loss for i in range(0, epochs+1): loss = opt.step(step_opt) 

Sebagai kesimpulan, Anda dapat menambahkan beberapa contoh:





























Source: https://habr.com/ru/post/id453512/


All Articles