Die Stilübertragung ist der Prozess der Konvertierung des Stils der Quelle in den Stil des ausgewählten Bildes und basiert auf dem zuvor trainierten Convolution-Netzwerktyp (CNN). So viel hängt von der Wahl dieses trainierten Netzwerks ab. Glücklicherweise gibt es solche Netzwerke und es gibt eine große Auswahl, aber VGG-16 wird hier verwendet.
Zuerst müssen Sie die erforderlichen Bibliotheken verbinden
Code für die Bibliotheksdeklarationimport time import torch from torch.autograd import Variable import torch.nn as nn import torch.nn.functional as F from torch import optim import torchvision from torchvision import transforms from io import BytesIO from PIL import Image from collections import OrderedDict from google.colab import files
Dann müssen Sie die vorab trainierte Netzwerkklasse VGG-16 deklarieren
Klassencode VGG-16 class VGG16(nn.Module): def __init__(self, pool='max'): super(VGG, self).__init__() self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1) self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1) self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1) self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1) self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1) self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1) self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1) self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1) self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1) self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1) self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1) self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1) if pool == 'max': self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2) elif pool == 'avg': self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2) self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2) self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2) self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2) self.pool5 = nn.AvgPool2d(kernel_size=2, stride=2) def forward(self, x, layers): out = {} out['relu1_1'] = F.relu(self.conv1_1(x)) out['relu1_2'] = F.relu(self.conv1_2(out['relu1_1'])) out['pool1'] = self.pool1(out['relu1_2']) out['relu2_1'] = F.relu(self.conv2_1(out['pool1'])) out['relu2_2'] = F.relu(self.conv2_2(out['relu2_1'])) out['pool2'] = self.pool2(out['relu2_2']) out['relu3_1'] = F.relu(self.conv3_1(out['pool2'])) out['relu3_2'] = F.relu(self.conv3_2(out['relu3_1'])) out['relu3_3'] = F.relu(self.conv3_3(out['relu3_2'])) out['relu3_4'] = F.relu(self.conv3_4(out['relu3_3'])) out['pool3'] = self.pool3(out['relu3_4']) out['relu4_1'] = F.relu(self.conv4_1(out['pool3'])) out['relu4_2'] = F.relu(self.conv4_2(out['relu4_1'])) out['relu4_3'] = F.relu(self.conv4_3(out['relu4_2'])) out['relu4_4'] = F.relu(self.conv4_4(out['relu4_3'])) out['pool4'] = self.pool4(out['relu4_4']) out['relu5_1'] = F.relu(self.conv5_1(out['pool4'])) out['relu5_2'] = F.relu(self.conv5_2(out['relu5_1'])) out['relu5_3'] = F.relu(self.conv5_3(out['relu5_2'])) out['relu5_4'] = F.relu(self.conv5_4(out['relu5_3'])) out['pool5'] = self.pool5(out['relu5_4']) return [out[key] for key in layers]
Als nächstes müssen Sie die VGG-16-Gewichte herunterladen und laden, nachdem Sie sie nach Möglichkeit zuvor auf die Grafikkarte übertragen haben
vgg = VGG16() vgg.load_state_dict(torch.load('vgg_conv.pth')) for param in vgg.parameters(): param.requires_grad = False if torch.cuda.is_available(): vgg.cuda()
Dabei ist vgg_conv.pth der Name der Netzwerkgewichtsdatei.
In diesem Fall muss das Training der Parameter im Netzwerk deaktiviert werden, andernfalls können Sie die geladenen Gewichte verderben, die länger als einen Tag trainiert wurden.
Anschließend werden die Funktionen zum Konvertieren von Eingabebildern angekündigt, um sie in die Form von Bildern zu bringen, auf denen das VGG-16-Netzwerk trainiert wurde
Code für die Eingabebildkonvertierungsfunktionen SIZE_IMAGE = 512 to_mean_tensor = transforms.Compose([transforms.Resize(SIZE_IMAGE), transforms.ToTensor(), transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])]), transforms.Normalize(mean=[0.40760392, 0.45795686, 0.48501961], std=[1,1,1]), transforms.Lambda(lambda x: x.mul_(255)), ]) to_unmean_tensor = transforms.Compose([transforms.Lambda(lambda x: x.div_(255)), transforms.Normalize(mean=[-0.40760392, -0.45795686, -0.48501961], std=[1,1,1]), transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])]), ]) to_image = transforms.Compose([transforms.ToPILImage()]) normalize_image = lambda t: to_image(torch.clamp(to_unmean_tensor(t), min=0, max=1))
to_mean_tensor - direkte Konvertierung
normalize_image - inverse Transformation
Als nächstes werden Gram-Matrixklassen und Verlustfunktionen für die Gram-Matrix angekündigt
class GramMatrix(nn.Module): def forward(self, input): b,c,h,w = input.size() F = input.view(b, c, h*w) G = torch.bmm(F, F.transpose(1,2)) G.div_(h*w) return G class GramMSELoss(nn.Module): def forward(self, input, target): out = nn.MSELoss()(GramMatrix()(input), target) return out
Die Gram-Matrix dient dazu, den räumlichen Bezug von Stildetails zu eliminieren.
Dann folgt das Laden und Konvertieren der Quell- und Stilbilder.
imgs = [style_img, content_img] imgs_torch = [to_mean_tensor(img) for img in imgs] if torch.cuda.is_available(): imgs_torch = [Variable(img.unsqueeze(0).cuda()) for img in imgs_torch] else: imgs_torch = [Variable(img.unsqueeze(0)) for img in imgs_torch] style_image, content_image = imgs_torch opt_img = Variable(content_image.data.clone(), requires_grad=True)
Wobei style_img und content_img Eingabebilder sind, die in Tensoren konvertiert und wenn möglich auf die Grafikkarte übertragen werden, und opt_img das Ergebnis der Stilübertragung enthält, während das ursprüngliche Bild als erstes verwendet wird.
Als nächstes werden Ebenen ausgewählt, Gewichte festgelegt und Verlustfunktionen initialisiert
Gewichts- und Verlustcode style_layers = ['relu1_1','relu2_1','relu3_1','relu4_1', 'relu5_1'] content_layers = ['relu4_2'] loss_layers = style_layers + content_layers losses = [GramMSELoss()] * len(style_layers) + [nn.MSELoss()] * len(content_layers) if torch.cuda.is_available(): losses = [loss.cuda() for loss in losses] style_weights = [1e3/n**2 for n in [64,128,256,512,512]] content_weights = [1e0] weights = style_weights + content_weights style_targets = [GramMatrix()(A).detach() for A in vgg(style_image, style_layers)] content_targets = [A.detach() for A in vgg(content_image, content_layers)] targets = style_targets + content_targets
Und der letzte Schritt ist der Prozess der Stilübertragung
epochs = 300 opt = optim.LBFGS([opt_img]) def step_opt(): opt.zero_grad() out_layers = vgg(opt_img, loss_layers) layer_losses = [] for j, out in enumerate(out_layers): layer_losses.append(weights[j] * losses[j](out, targets[j])) loss = sum(layer_losses) loss.backward() return loss for i in range(0, epochs+1): loss = opt.step(step_opt)
Abschließend können Sie einige Beispiele hinzufügen:














