рд╣рд╛рдп, рд╣рдмреНрд░, рдЗрд╕ рд▓реЗрдЦ рдореЗрдВ рдореИрдВ рдЗрдЧреНрдирд╛рдЗрдЯ рд▓рд╛рдЗрдмреНрд░реЗрд░реА рдХреЗ рдмрд╛рд░реЗ рдореЗрдВ рдмрд╛рдд рдХрд░реВрдВрдЧрд╛, рдЬрд┐рд╕рдХреЗ рд╕рд╛рде рдЖрдк PyTorch рдврд╛рдВрдЪреЗ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ рдЖрд╕рд╛рдиреА рд╕реЗ рддрдВрддреНрд░рд┐рдХрд╛ рдиреЗрдЯрд╡рд░реНрдХ рдХреЛ рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдФрд░ рдкрд░реАрдХреНрд╖рдг рдХрд░ рд╕рдХрддреЗ рд╣реИрдВред
рдЗрдЧреНрдирд╛рдЗрдЯ рдХреЗ рд╕рд╛рде , рдЖрдк рдиреЗрдЯрд╡рд░реНрдХ рдХреЛ рдХреЗрд╡рд▓ рдХреБрдЫ рд▓рд╛рдЗрдиреЛрдВ рдореЗрдВ рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП рд╕рд╛рдЗрдХрд┐рд▓ рд▓рд┐рдЦ рд╕рдХрддреЗ рд╣реИрдВ, рдмреЙрдХреНрд╕ рд╕реЗ рдорд╛рдирдХ рдореИрдЯреНрд░рд┐рдХреНрд╕ рдЧрдгрдирд╛ рдЬреЛрдбрд╝ рд╕рдХрддреЗ рд╣реИрдВ, рдореЙрдбрд▓ рдХреЛ рдмрдЪрд╛ рд╕рдХрддреЗ рд╣реИрдВ, рдЖрджрд┐ред рдареАрдХ рд╣реИ, рдЙрди рд▓реЛрдЧреЛрдВ рдХреЗ рд▓рд┐рдП рдЬреЛ рдЯреАрдПрдл рд╕реЗ рдкрд┐рдпрд░рдЯреЗрдХ рдХреЗ рд▓рд┐рдП рдЪрд▓реЗ рдЧрдП рд╣реИрдВ, рд╣рдо рдХрд╣ рд╕рдХрддреЗ рд╣реИрдВ рдХрд┐ рдкреНрд░рдЬреНрд╡рд▓рд┐рдд рдкреБрд╕реНрддрдХрд╛рд▓рдп рдкрд╛рдпрд░ рдХреЗ рд▓рд┐рдП рдХреЗрд░рд╕ рд╣реИред
рдпрд╣ рд▓реЗрдЦ рд╡рд┐рд╕реНрддрд╛рд░ рд╕реЗ рдЖрдЧрдгрди рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ рдПрдХ рд╡рд░реНрдЧреАрдХрд░рдг рдХрд╛рд░реНрдп рдХреЗ рд▓рд┐рдП рдПрдХ рддрдВрддреНрд░рд┐рдХрд╛ рдиреЗрдЯрд╡рд░реНрдХ рдХреЗ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХрд╛ рдПрдХ рдЙрджрд╛рд╣рд░рдг рд╡рд┐рд╕реНрддрд╛рд░ рд╕реЗ рдЬрд╛рдВрдЪ рдХрд░реЗрдЧрд╛ ред

PyTorch рдореЗрдВ рдЕрдзрд┐рдХ рдЖрдЧ рдЬреЛрдбрд╝реЗрдВ
рдореИрдВ рдЗрд╕ рдмрд╛рд░реЗ рдореЗрдВ рдмрд╛рдд рдХрд░рдиреЗ рдореЗрдВ рд╕рдордп рдмрд░реНрдмрд╛рдж рдирд╣реАрдВ рдХрд░реВрдВрдЧрд╛ рдХрд┐ рдкрд╛рдЗрд░реЙрдЪ рдлреНрд░реЗрдорд╡рд░реНрдХ рдХрд┐рддрдирд╛ рдЕрдЪреНрдЫрд╛ рд╣реИред рдЬреЛ рдХреЛрдИ рдкрд╣рд▓реЗ рд╕реЗ рд╣реА рдЗрд╕рдХрд╛ рдЗрд╕реНрддреЗрдорд╛рд▓ рдХрд░ рдЪреБрдХрд╛ рд╣реИ, рд╡рд╣ рд╕рдордЭрддрд╛ рд╣реИ рдХрд┐ рдореИрдВ рдХреНрдпрд╛ рд▓рд┐рдЦ тАЛтАЛрд░рд╣рд╛ рд╣реВрдВред рд▓реЗрдХрд┐рди, рдЕрдкрдиреЗ рд╕рднреА рдлрд╛рдпрджреЛрдВ рдХреЗ рд╕рд╛рде, рдпрд╣ рдЕрднреА рднреА рдкреНрд░рд╢рд┐рдХреНрд╖рдг, рдкрд░реАрдХреНрд╖рдг, рддрдВрддреНрд░рд┐рдХрд╛ рдиреЗрдЯрд╡рд░реНрдХ рдХреЗ рдкрд░реАрдХреНрд╖рдг рдХреЗ рд▓рд┐рдП рдЪрдХреНрд░ рд▓рд┐рдЦрдиреЗ рдХреЗ рдорд╛рдорд▓реЗ рдореЗрдВ рдирд┐рдореНрди рд╕реНрддрд░ рдХрд╛ рд╣реИред
рдЕрдЧрд░ рд╣рдо PyTorch рдврд╛рдВрдЪреЗ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдиреЗ рдХреЗ рдЖрдзрд┐рдХрд╛рд░рд┐рдХ рдЙрджрд╛рд╣рд░рдгреЛрдВ рдХреЛ рджреЗрдЦрддреЗ рд╣реИрдВ, рддреЛ рд╣рдо рдПрдкреЛрдЪ рджреНрд╡рд╛рд░рд╛ рдФрд░ рдЧреНрд░рд┐рдб рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЛрдб рдореЗрдВ рдирд┐рд░реНрдзрд╛рд░рд┐рдд рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЗ рдмреИрдЪреЛрдВ рджреНрд╡рд╛рд░рд╛ рдХрдо рд╕реЗ рдХрдо рджреЛ рдЪрдХреНрд░реЛрдВ рдХреЛ рджреЗрдЦреЗрдВрдЧреЗ:
for epoch in range(1, epochs + 1): for batch_idx, (data, target) in enumerate(train_loader):
рдЖрдЧреНрдиреЗрдп рдкреБрд╕реНрддрдХрд╛рд▓рдп рдХрд╛ рдореБрдЦреНрдп рд╡рд┐рдЪрд╛рд░ рдЗрди рд▓реВрдкреЛрдВ рдХреЛ рдПрдХ рд╣реА рд╡рд░реНрдЧ рдореЗрдВ рдмрджрд▓рдирд╛ рд╣реИ, рдЬрдмрдХрд┐ рдЙрдкрдпреЛрдЧрдХрд░реНрддрд╛ рдЗрди рд╣реИрдВрдбрд▓рд░ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ рдЗрди рд▓реВрдкреЛрдВ рдХреЗ рд╕рд╛рде рдмрд╛рддрдЪреАрдд рдХрд░рдиреЗ рдХреА рдЕрдиреБрдорддрд┐ рджреЗрддрд╛ рд╣реИред
рдирддреАрдЬрддрди, рдорд╛рдирдХ рдЧрд╣рди рд╢рд┐рдХреНрд╖рдг рдХрд╛рд░реНрдпреЛрдВ рдХреЗ рдорд╛рдорд▓реЗ рдореЗрдВ, рд╣рдо рдХреЛрдб рдХреА рд▓рд╛рдЗрдиреЛрдВ рдХреА рд╕рдВрдЦреНрдпрд╛ рдкрд░ рдмрд╣реБрдд рдмрдЪрдд рдХрд░ рд╕рдХрддреЗ рд╣реИрдВред рдХрдо рд▓рд╛рдЗрдиреЗрдВ - рдХрдо рддреНрд░реБрдЯрд┐рдпрд╛рдВ!
рдЙрджрд╛рд╣рд░рдг рдХреЗ рд▓рд┐рдП, рддреБрд▓рдирд╛ рдХреЗ рд▓рд┐рдП, рдмрд╛рдИрдВ рдУрд░ рдкреНрд░рдЬреНрд╡рд▓рди рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рддреЗ рд╣реБрдП рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдФрд░ рдореЙрдбрд▓ рд╕рддреНрдпрд╛рдкрди рдХреЗ рд▓рд┐рдП рдХреЛрдб рд╣реИ, рдФрд░ рджрд╛рдИрдВ рдУрд░ рд╢реБрджреНрдз Pyoror рд╣реИ:

рддреЛ рдлрд┐рд░, рдХреНрдпрд╛ рдкреНрд░рдЬреНрд╡рд▓рд┐рдд рдХреЗ рд▓рд┐рдП рдЕрдЪреНрдЫрд╛ рд╣реИ ?
- рдЕрдм рдЖрдкрдХреЛ рдкреНрд░рддреНрдпреЗрдХ рдХрд╛рд░реНрдп рдЫреЛрд░реЛрдВ рдХреЗ
for epoch in range(n_epochs)
for batch in data_loader
рдФрд░ for batch in data_loader
рд▓рд┐рдП рд▓рд┐рдЦрдиреЗ рдХреА рдЖрд╡рд╢реНрдпрдХрддрд╛ рдирд╣реАрдВ рд╣реИред - рдЖрдкрдХреЛ рдмреЗрд╣рддрд░ рдлреИрдХреНрдЯрд░ рдХреЛрдб рдмрдирд╛рдиреЗ рдХреА рдЕрдиреБрдорддрд┐ рджреЗрддрд╛ рд╣реИ
- рдЖрдк рдмреЙрдХреНрд╕ рд╕реЗ рдмрд╛рд╣рд░ рдмреБрдирд┐рдпрд╛рджреА рдореИрдЯреНрд░рд┐рдХреНрд╕ рдХреА рдЧрдгрдирд╛ рдХрд░рдиреЗ рдХреА рдЕрдиреБрдорддрд┐ рджреЗрддрд╛ рд╣реИ
- рдкреНрд░рдХрд╛рд░ рдХреЗ "рдмрдиреНрд╕" рдкреНрд░рджрд╛рди рдХрд░рддрд╛ рд╣реИ
- рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЗ рджреМрд░рд╛рди рдирд╡реАрдирддрдо рдФрд░ рд╕рд░реНрд╡рд╢реНрд░реЗрд╖реНрда рдореЙрдбрд▓ (рднреА рдЖрд╢рд╛рд╡рд╛рджреА рдФрд░ рд╕реАрдЦрдиреЗ рдХреА рджрд░ рдЕрдиреБрд╕реВрдЪрдХ) рдХреЛ рдмрдЪрд╛рддреЗ рд╣реБрдП,
- рдЬрд▓реНрджреА рд╕реАрдЦрдирд╛ рдмрдВрдж рдХрд░реЛ
- рдЖрджрд┐
- рдЖрд╕рд╛рдиреА рд╕реЗ рд╡рд┐рдЬрд╝реБрдЕрд▓рд╛рдЗрдЬрд╝реЗрд╢рди рдЯреВрд▓ рдХреЗ рд╕рд╛рде рдПрдХреАрдХреГрдд рд╣реЛрддрд╛ рд╣реИ: рдЯреЗрдВрд╕реЛрд░рдмреЛрд░реНрдб, рд╡рд┐рдЬрд╝рдбрдо, ...
рдПрдХ рдЕрд░реНрде рдореЗрдВ, рдЬреИрд╕рд╛ рдХрд┐ рдкрд╣рд▓реЗ рд╣реА рдЙрд▓реНрд▓реЗрдЦ рдХрд┐рдпрд╛ рдЧрдпрд╛ рд╣реИ, рдкреНрд░рдЬреНрд╡рд▓рд┐рдд рдкреБрд╕реНрддрдХрд╛рд▓рдп рдХреА рддреБрд▓рдирд╛ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдФрд░ рдкрд░реАрдХреНрд╖рдг рдиреЗрдЯрд╡рд░реНрдХ рдХреЗ рд▓рд┐рдП рд╕рднреА рдкреНрд░рд╕рд┐рджреНрдз рдХреЗрд░ рдФрд░ рдЗрд╕рдХреЗ рдПрдкреАрдЖрдИ рдХреЗ рд╕рд╛рде рдХреА рдЬрд╛ рд╕рдХрддреА рд╣реИред рдЗрд╕рдХреЗ рдЕрд▓рд╛рд╡рд╛, рдкрд╣рд▓реА рдирдЬрд╝рд░ рдореЗрдВ рдкреНрд░рдЬреНрд╡рд▓рд┐рдд рдкреБрд╕реНрддрдХрд╛рд▓рдп, tnt рдкреБрд╕реНрддрдХрд╛рд▓рдп рдХреЗ рд╕рдорд╛рди рд╣реИ, рдХреНрдпреЛрдВрдХрд┐ рд╢реБрд░реВ рдореЗрдВ рджреЛрдиреЛрдВ рдкреБрд╕реНрддрдХрд╛рд▓рдпреЛрдВ рдореЗрдВ рд╕рд╛рдорд╛рдиреНрдп рд▓рдХреНрд╖реНрдп рдереЗ рдФрд░ рдЙрдирдХреЗ рдХрд╛рд░реНрдпрд╛рдиреНрд╡рдпрди рдХреЗ рд▓рд┐рдП рд╕рдорд╛рди рд╡рд┐рдЪрд╛рд░ рдереЗред
рддреЛ, рдкреНрд░рдХрд╛рд╢:
pip install pytorch-ignite
рдпрд╛
conda install ignite -c pytorch
рдЕрдЧрд▓рд╛, рдПрдХ рд╡рд┐рд╢рд┐рд╖реНрдЯ рдЙрджрд╛рд╣рд░рдг рдХреЗ рд╕рд╛рде, рд╣рдо рдЦреБрдж рдХреЛ рдЗрдЧреНрдирд╛рдЗрдЯ рд▓рд╛рдЗрдмреНрд░реЗрд░реА рдПрдкреАрдЖрдИ рдХреЗ рд╕рд╛рде рдкрд░рд┐рдЪрд┐рдд рдХрд░реЗрдВрдЧреЗред
рдЖрдЧреНрдиреЗрдп рдХреЗ рд╕рд╛рде рд╡рд░реНрдЧреАрдХрд░рдг рдХрд╛рд░реНрдп
рд▓реЗрдЦ рдХреЗ рдЗрд╕ рднрд╛рдЧ рдореЗрдВ, рд╣рдо рдкреНрд░рдЬреНрд╡рд▓рд┐рдд рдкреБрд╕реНрддрдХрд╛рд▓рдп рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ рд╡рд░реНрдЧреАрдХрд░рдг рд╕рдорд╕реНрдпрд╛ рдХреЗ рд▓рд┐рдП рдПрдХ рддрдВрддреНрд░рд┐рдХрд╛ рдиреЗрдЯрд╡рд░реНрдХ рдХреЗ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЗ рдПрдХ рд╕реНрдХреВрд▓ рдЙрджрд╛рд╣рд░рдг рдкрд░ рд╡рд┐рдЪрд╛рд░ рдХрд░реЗрдВрдЧреЗред
рддреЛ, рдЖрдЗрдП рдПрдХ рд╕рд░рд▓ рдбреЗрдЯрд╛рд╕реЗрдЯ рд▓реЗрдВ рдЬрд┐рд╕рдореЗрдВ рдХреЗрдЧрд▓ рдХреЗ рд╕рд╛рде рдлрд▓реЛрдВ рдХреЗ рдЪрд┐рддреНрд░ рд╣реИрдВ ред рдХрд╛рд░реНрдп рдкреНрд░рддреНрдпреЗрдХ рдлрд▓ рдЪрд┐рддреНрд░ рдХреЗ рд╕рд╛рде рд╕рдВрдмрдВрдзрд┐рдд рд╡рд░реНрдЧ рдХреЛ рдЬреЛрдбрд╝рдирд╛ рд╣реИред
рдЗрдЧреНрдирд╛рдЗрдЯ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдиреЗ рд╕реЗ рдкрд╣рд▓реЗ, рдЖрдЗрдП рдореБрдЦреНрдп рдШрдЯрдХреЛрдВ рдХреЛ рдкрд░рд┐рднрд╛рд╖рд┐рдд рдХрд░реЗрдВ:
рдбреЗрдЯрд╛ рд╕реНрдЯреНрд░реАрдо
- рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдирдореВрдирд╛ рдмреИрдЪрд░ рд▓реЛрдбрд░,
train_loader
- рдЪреЗрдХрдЖрдЙрдЯ рдмреИрдЪ рдбрд╛рдЙрдирд▓реЛрдбрд░,
val_loader
рдореЙрдбрд▓:
- рдорд╢рд╛рд▓ рд╕реЗ рдЫреЛрдЯреЗ рдирд┐рдЪреЛрдбрд╝рдиреЗ рдХрд╛ рдЧреНрд░рд┐рдб рд▓реЗ
torchvision
рдЕрдиреБрдХреВрд▓рди рдПрд▓реНрдЧреЛрд░рд┐рдердо:
рдиреБрдХрд╕рд╛рди рд╕рдорд╛рд░реЛрд╣:
- рдХреНрд░реЛрд╕ рдПрдВрдЯреНрд░реЛрдкреА
рдХреЛрдб from pathlib import Path import numpy as np import torch from torch.utils.data import Dataset, DataLoader from torch.utils.data.dataset import Subset from torchvision.datasets import ImageFolder from torchvision.transforms import Compose, RandomResizedCrop, RandomVerticalFlip, RandomHorizontalFlip from torchvision.transforms import ColorJitter, ToTensor, Normalize FRUIT360_PATH = Path(".").resolve().parent / "input" / "fruits-360_dataset" / "fruits-360" device = "cuda" train_transform = Compose([ RandomHorizontalFlip(), RandomResizedCrop(size=32), ColorJitter(brightness=0.12), ToTensor(), Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) ]) val_transform = Compose([ RandomResizedCrop(size=32), ToTensor(), Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) ]) batch_size = 128 num_workers = 8 train_dataset = ImageFolder((FRUIT360_PATH /"Training").as_posix(), transform=train_transform, target_transform=None) val_dataset = ImageFolder((FRUIT360_PATH /"Test").as_posix(), transform=val_transform, target_transform=None) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True, pin_memory="cuda" in device) val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=False, pin_memory="cuda" in device)
import torch.nn as nn from torchvision.models.squeezenet import squeezenet1_1 model = squeezenet1_1(pretrained=False, num_classes=81) model.classifier[-1] = nn.AdaptiveAvgPool2d(1) model = model.to(device)
import torch.nn as nn from torch.optim import SGD optimizer = SGD(model.parameters(), lr=0.01, momentum=0.5) criterion = nn.CrossEntropyLoss()
рддреЛ рдЕрдм рдкреНрд░рдЬреНрд╡рд▓рд┐рдд рд╣реЛрдиреЗ рдХрд╛ рд╕рдордп рд╣реИ:
from ignite.engine import Engine, _prepare_batch def process_function(engine, batch): model.train() optimizer.zero_grad() x, y = _prepare_batch(batch, device=device) y_pred = model(x) loss = criterion(y_pred, y) loss.backward() optimizer.step() return loss.item() trainer = Engine(process_function)
рдЖрдЗрдП рджреЗрдЦреЗрдВ рдХрд┐ рдЗрд╕ рдХреЛрдб рдХрд╛ рдХреНрдпрд╛ рдЕрд░реНрде рд╣реИред
рдЗрдВрдЬрди рдХрд╛ Engine
ignite.engine.Engine
рд▓рд╛рдЗрдмреНрд░реЗрд░реА рд▓рд╛рдЗрдмреНрд░реЗрд░реА рдлреНрд░реЗрдорд╡рд░реНрдХ рд╣реИ, рдФрд░ рдЗрд╕ рдХреНрд▓рд╛рд╕ рдХрд╛ рдЙрджреНрджреЗрд╢реНрдп trainer
:
trainer = Engine(process_function)
рдпрд╣ рдПрдХ рдмреИрдЪ рдХреЗ рдкреНрд░рд╕рдВрд╕реНрдХрд░рдг рдХреЗ рд▓рд┐рдП рдЗрдирдкреБрдЯ рдлрдВрдХреНрд╢рди process_function
рд╕рд╛рде рдкрд░рд┐рднрд╛рд╖рд┐рдд рдХрд┐рдпрд╛ рдЧрдпрд╛ рд╣реИ рдФрд░ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдирдореВрдиреЗ рдХреЗ рд▓рд┐рдП рдкрд╛рд╕ рд▓рд╛рдЧреВ рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП рдХрд╛рд░реНрдп рдХрд░рддрд╛ рд╣реИред ignite.engine.Engine
рдХреНрд▓рд╛рд╕ рдХреЗ рдЕрдВрджрд░, рдирд┐рдореНрди рд╣реЛрддрд╛ рд╣реИ:
while epoch < max_epochs:
рд╡рд╛рдкрд╕ рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП process_function
рдлрд╝рдВрдХреНрд╢рди:
def process_function(engine, batch): model.train() optimizer.zero_grad() x, y = _prepare_batch(batch, device=device) y_pred = model(x) loss = criterion(y_pred, y) loss.backward() optimizer.step() return loss.item()
рд╣рдо рджреЗрдЦрддреЗ рд╣реИрдВ рдХрд┐ рдлрд╝рдВрдХреНрд╢рди рдХреЗ рдЕрдВрджрд░, рд╣рдо рд╣рдореЗрд╢рд╛ рдХреА рддрд░рд╣, рдореЙрдбрд▓ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЗ рдорд╛рдорд▓реЗ рдореЗрдВ, y_pred
рднрд╡рд┐рд╖реНрдпрд╡рд╛рдгрд┐рдпреЛрдВ рдХреА рдЧрдгрдирд╛ рдХрд░рддреЗ рд╣реИрдВ, рд╣рд╛рдирд┐ рдлрд╝рдВрдХреНрд╢рди, loss
рдФрд░ рдЧреНрд░реЗрдбрд┐рдПрдВрдЯреНрд╕ рдХреА рдЧрдгрдирд╛ рдХрд░рддреЗ рд╣реИрдВред рдЙрддреНрддрд░рд╛рд░реНрджреНрдз рдЖрдкрдХреЛ рдореЙрдбрд▓ рд╡рдЬрди рдЕрдкрдбреЗрдЯ рдХрд░рдиреЗ рдХреА рдЕрдиреБрдорддрд┐ рджреЗрддрд╛ рд╣реИ: optimizer.step()
ред
рд╕рд╛рдорд╛рдиреНрдп рддреМрд░ рдкрд░, process_function
рдлрд╝рдВрдХреНрд╢рди рдХреЗ рдХреЛрдб рдкрд░ рдХреЛрдИ рдкреНрд░рддрд┐рдмрдВрдз рдирд╣реАрдВ рд╣реИред рд╣рдо рдХреЗрд╡рд▓ рдпрд╣ рдиреЛрдЯ рдХрд░рддреЗ рд╣реИрдВ рдХрд┐ рдЗрдирдкреБрдЯ рдХреЗ рд░реВрдк рдореЗрдВ рджреЛ рддрд░реНрдХ рд╣реИрдВ: Engine
рдСрдмреНрдЬреЗрдХреНрдЯ (рд╣рдорд╛рд░реЗ рдорд╛рдорд▓реЗ рдореЗрдВ, trainer
) рдФрд░ рдбреЗрдЯрд╛ рд▓реЛрдбрд░ рд╕реЗ рдмреИрдЪред рдЗрд╕рд▓рд┐рдП, рдЙрджрд╛рд╣рд░рдг рдХреЗ рд▓рд┐рдП, рдПрдХ рддрдВрддреНрд░рд┐рдХрд╛ рдиреЗрдЯрд╡рд░реНрдХ рдХрд╛ рдкрд░реАрдХреНрд╖рдг рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП, рд╣рдо ignite.engine.Engine
рд╡рд░реНрдЧ рдХреА рдПрдХ рдЕрдиреНрдп рд╡рд╕реНрддреБ рдХреЛ рдкрд░рд┐рднрд╛рд╖рд┐рдд рдХрд░ рд╕рдХрддреЗ рд╣реИрдВ, рдЬрд┐рд╕рдореЗрдВ рдЗрдирдкреБрдЯ рдлрд╝рдВрдХреНрд╢рди рдХреЗрд╡рд▓ рднрд╡рд┐рд╖реНрдпрд╡рд╛рдгрд┐рдпреЛрдВ рдХреА рдЧрдгрдирд╛ рдХрд░рддрд╛ рд╣реИ, рдФрд░ рдПрдХ рдмрд╛рд░ рдкрд░реАрдХреНрд╖рдг рдирдореВрдиреЗ рдХреЗ рдорд╛рдзреНрдпрдо рд╕реЗ рдПрдХ рдкрд╛рд╕ рд▓рд╛рдЧреВ рдХрд░рддрд╛ рд╣реИред рдЗрд╕рдХреЗ рдмрд╛рд░реЗ рдореЗрдВ рдмрд╛рдж рдореЗрдВ рдкрдврд╝реЗрдВред
рддреЛ, рдЙрдкрд░реЛрдХреНрдд рдХреЛрдб рдХреЗрд╡рд▓ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рд╢реБрд░реВ рдХрд┐рдП рдмрд┐рдирд╛ рдЖрд╡рд╢реНрдпрдХ рд╡рд╕реНрддреБрдУрдВ рдХреЛ рдкрд░рд┐рднрд╛рд╖рд┐рдд рдХрд░рддрд╛ рд╣реИред рдореВрд▓ рд░реВрдк рд╕реЗ, рдПрдХ рдиреНрдпреВрдирддрдо рдЙрджрд╛рд╣рд░рдг рдореЗрдВ, рдЖрдк рд╡рд┐рдзрд┐ рдХреЛ рдХреЙрд▓ рдХрд░ рд╕рдХрддреЗ рд╣реИрдВ:
trainer.run(train_loader, max_epochs=10)
рдФрд░ рдпрд╣ рдХреЛрдб "рдЪреБрдкрдЪрд╛рдк" (рдордзреНрдпрд╡рд░реНрддреА рдкрд░рд┐рдгрд╛рдореЛрдВ рдХреЗ рдХрд┐рд╕реА рднреА рд╡реНрдпреБрддреНрдкрддреНрддрд┐ рдХреЗ рдмрд┐рдирд╛) рдореЙрдбрд▓ рдХреЛ рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП рдкрд░реНрдпрд╛рдкреНрдд рд╣реИред
рдПрдХ рдиреЛрдЯрдпрд╣ рднреА рдзреНрдпрд╛рди рджреЗрдВ рдХрд┐ рдЗрд╕ рдкреНрд░рдХрд╛рд░ рдХреЗ рдХрд╛рд░реНрдпреЛрдВ рдХреЗ рд▓рд┐рдП рд▓рд╛рдЗрдмреНрд░реЗрд░реА рдореЗрдВ trainer
рдСрдмреНрдЬреЗрдХреНрдЯ рдмрдирд╛рдиреЗ рдХреЗ рд▓рд┐рдП рдПрдХ рд╕реБрд╡рд┐рдзрд╛рдЬрдирдХ рддрд░реАрдХрд╛ рд╣реИ:
from ignite.engine import create_supervised_trainer trainer = create_supervised_trainer(model, optimizer, criterion, device)
рдмреЗрд╢рдХ, рд╡реНрдпрд╡рд╣рд╛рд░ рдореЗрдВ, рдЙрдкрд░реЛрдХреНрдд рдЙрджрд╛рд╣рд░рдг рдереЛрдбрд╝рд╛ рдмреНрдпрд╛рдЬ рдХрд╛ рд╣реИ, рдЗрд╕рд▓рд┐рдП "рдХреЛрдЪ" рдХреЗ рд▓рд┐рдП рдирд┐рдореНрдирд▓рд┐рдЦрд┐рдд рд╡рд┐рдХрд▓реНрдк рдЬреЛрдбрд╝реЗрдВред
- рдкреНрд░рддреНрдпреЗрдХ 50 рдкреБрдирд░рд╛рд╡реГрддреНрддрд┐рдпреЛрдВ рдХреЛ рдиреБрдХрд╕рд╛рди рдлрд╝рдВрдХреНрд╢рди рдХрд╛ рдкреНрд░рджрд░реНрд╢рди
- рдПрдХ рдирд┐рд░реНрдзрд╛рд░рд┐рдд рдореЙрдбрд▓ рдХреЗ рд╕рд╛рде рдкреНрд░рд╢рд┐рдХреНрд╖рдг рд╕реЗрдЯ рдкрд░ рдореИрдЯреНрд░рд┐рдХреНрд╕ рдХреА рдЧрдгрдирд╛ рдХреА рд╢реБрд░реБрдЖрдд
- рдкреНрд░рддреНрдпреЗрдХ рдпреБрдЧ рдХреЗ рдмрд╛рдж рдкрд░реАрдХреНрд╖рдг рдирдореВрдиреЗ рдкрд░ рдореИрдЯреНрд░рд┐рдХреНрд╕ рдХреА рдЧрдгрдирд╛ рдХреА рд╢реБрд░реБрдЖрдд
- рдкреНрд░рддреНрдпреЗрдХ рдпреБрдЧ рдХреЗ рдмрд╛рдж рдореЙрдбрд▓ рдорд╛рдкрджрдВрдбреЛрдВ рдХреЛ рд╕рд╣реЗрдЬрдирд╛
- рддреАрди рд╕рд░реНрд╡рд╢реНрд░реЗрд╖реНрда рдореЙрдбрд▓реЛрдВ рдХрд╛ рд╕рдВрд░рдХреНрд╖рдг
- рдпреБрдЧ рдХреЗ рдЖрдзрд╛рд░ рдкрд░ рд╕реАрдЦрдиреЗ рдХреА рдЧрддрд┐ рдореЗрдВ рдмрджрд▓рд╛рд╡ (рд╕реАрдЦрдиреЗ рдХреА рджрд░ рдирд┐рд░реНрдзрд╛рд░рдг)
- рдкреНрд░рд╛рд░рдВрднрд┐рдХ рд░реЛрдХ рдкреНрд░рд╢рд┐рдХреНрд╖рдг (рдкреНрд░рд╛рд░рдВрднрд┐рдХ рд░реЛрдХ)
рдЗрд╡реЗрдВрдЯреНрд╕ рдПрдВрдб рдЗрд╡реЗрдВрдЯ рд╣реИрдВрдбрд▓рд░
"рдЯреНрд░реЗрдирд░" рдХреЗ рд▓рд┐рдП рдЙрдкрд░реЛрдХреНрдд рд╡рд┐рдХрд▓реНрдкреЛрдВ рдХреЛ рдЬреЛрдбрд╝рдиреЗ рдХреЗ рд▓рд┐рдП, рдЗрдЧреНрдирд╛рдЗрдЯ рд▓рд╛рдЗрдмреНрд░реЗрд░реА рдПрдХ рдЗрд╡реЗрдВрдЯ рд╕рд┐рд╕реНрдЯрдо рдФрд░ рдХрд╕реНрдЯрдо рдЗрд╡реЗрдВрдЯ рд╣реИрдВрдбрд▓рд░ рдХреЗ рд▓реЙрдиреНрдЪ рдХреА рд╕реБрд╡рд┐рдзрд╛ рдкреНрд░рджрд╛рди рдХрд░рддрд╛ рд╣реИред рдЗрд╕ рдкреНрд░рдХрд╛рд░, рдЙрдкрдпреЛрдЧрдХрд░реНрддрд╛ рдкреНрд░рддреНрдпреЗрдХ рдЪрд░рдг рдореЗрдВ Engine
рд╡рд░реНрдЧ рдХреА рдПрдХ рд╡рд╕реНрддреБ рдХреЛ рдирд┐рдпрдВрддреНрд░рд┐рдд рдХрд░ рд╕рдХрддрд╛ рд╣реИ:
- рдЗрдВрдЬрди рд╢реБрд░реВ / рдкреВрд░рд╛ рд╣реБрдЖ
- рдпреБрдЧ рд╢реБрд░реВ рд╣реБрдЖ / рд╕рдорд╛рдкреНрдд рд╣реБрдЖ
- рдмреИрдЪ рдкреБрдирд░рд╛рд╡реГрддреНрддрд┐ рд╢реБрд░реВ / рд╕рдорд╛рдкреНрдд рд╣реЛ рдЧрдпрд╛
рдФрд░ рд╣рд░ рдШрдЯрдирд╛ рдкрд░ рдЕрдкрдирд╛ рдХреЛрдб рдЪрд▓рд╛рдПрдВред
рд╣рд╛рдирд┐ рдлрд╝рдВрдХреНрд╢рди рдорд╛рди рдкреНрд░рджрд░реНрд╢рд┐рдд рдХрд░рддрд╛ рд╣реИ
рдРрд╕рд╛ рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП, рдмрд╕ рдЙрд╕ рдлрд╝рдВрдХреНрд╢рди рдХреЛ рдирд┐рд░реНрдзрд╛рд░рд┐рдд рдХрд░реЗрдВ рдЬрд┐рд╕рдореЗрдВ рдЖрдЙрдЯрдкреБрдЯ рд╕реНрдХреНрд░реАрди рдкрд░ рдкреНрд░рджрд░реНрд╢рд┐рдд рд╣реЛрдЧрд╛, рдФрд░ рдЗрд╕реЗ "рдЯреНрд░реЗрдирд░" рдореЗрдВ рдЬреЛрдбрд╝реЗрдВ:
from ignite.engine import Events log_interval = 50 @trainer.on(Events.ITERATION_COMPLETED) def log_training_loss(engine): iteration = (engine.state.iteration - 1) % len(train_loader) + 1 if iteration % log_interval == 0: print("Epoch[{}] Iteration[{}/{}] Loss: {:.4f}" .format(engine.state.epoch, iteration, len(train_loader), engine.state.output))
рдЗрд╡реЗрдВрдЯ рд╣реИрдВрдбрд▓рд░ рдЬреЛрдбрд╝рдиреЗ рдХреЗ рд╡рд╛рд╕реНрддрд╡ рдореЗрдВ рджреЛ рддрд░реАрдХреЗ рд╣реИрдВ: add_event_handler
рдорд╛рдзреНрдпрдо рд╕реЗ, рдпрд╛ рдбреЗрдХреЛрд░реЗрдЯрд░ рдХреЗ рдорд╛рдзреНрдпрдо рд╕реЗред рдКрдкрд░ рдЬреИрд╕рд╛ рдЗрд╕ рдкреНрд░рдХрд╛рд░ рдХрд┐рдпрд╛ рдЬрд╛ рд╕рдХрддрд╛ рд╣реИ:
from ignite.engine import Events log_interval = 50 def log_training_loss(engine):
рдзреНрдпрд╛рди рджреЗрдВ рдХрд┐ рдХрд┐рд╕реА рднреА рддрд░реНрдХ рдХреЛ рдЗрд╡реЗрдВрдЯ рд╣реИрдВрдбрд▓рд┐рдВрдЧ рдлрд╝рдВрдХреНрд╢рди рдореЗрдВ рдкрд╛рд╕ рдХрд┐рдпрд╛ рдЬрд╛ рд╕рдХрддрд╛ рд╣реИред рд╕рд╛рдорд╛рдиреНрдп рддреМрд░ рдкрд░, рдРрд╕рд╛ рдлрд╝рдВрдХреНрд╢рди рдЗрд╕ рддрд░рд╣ рджрд┐рдЦреЗрдЧрд╛:
def custom_handler(engine, *args, **kwargs): pass trainer.add_event_handler(Events.ITERATION_COMPLETED, custom_handler, *args, **kwargs)
рддреЛ, рдЖрдЗрдП рдПрдХ рдпреБрдЧ рдкрд░ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рд╢реБрд░реВ рдХрд░реЗрдВ рдФрд░ рджреЗрдЦреЗрдВ рдХрд┐ рдХреНрдпрд╛ рд╣реЛрддрд╛ рд╣реИ:
output = trainer.run(train_loader, max_epochs=1)
Epoch[1] Iteration[50/322] Loss: 4.3459 Epoch[1] Iteration[100/322] Loss: 4.2801 Epoch[1] Iteration[150/322] Loss: 4.2294 Epoch[1] Iteration[200/322] Loss: 4.1467 Epoch[1] Iteration[250/322] Loss: 3.8607 Epoch[1] Iteration[300/322] Loss: 3.6688
рдмреБрд░рд╛ рдирд╣реАрдВ рд╣реИ! рдЖрдЧреЗ рдЪрд▓рддреЗ рд╣реИрдВред
рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдФрд░ рдкрд░реАрдХреНрд╖рдг рдирдореВрдиреЛрдВ рдкрд░ рдореИрдЯреНрд░рд┐рдХреНрд╕ рдХреА рдЧрдгрдирд╛ рд╢реБрд░реВ рдХрд░рдирд╛
рдЖрдЗрдП рдирд┐рдореНрдирд▓рд┐рдЦрд┐рдд рдореАрдЯреНрд░рд┐рдХ рдХреА рдЧрдгрдирд╛ рдХрд░реЗрдВ: рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреА рдУрд░ рд╕реЗ рдкреНрд░рддреНрдпреЗрдХ рдпреБрдЧ рдХреЗ рдмрд╛рдж рдФрд╕рдд рд╕рдЯреАрдХрддрд╛, рдФрд╕рдд рдкреВрд░реНрдгрддрд╛ рдФрд░ рд╕рдВрдкреВрд░реНрдг рдкрд░реАрдХреНрд╖рдг рдирдореВрдирд╛ред рдзреНрдпрд╛рди рджреЗрдВ рдХрд┐ рд╣рдо рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЗ рдкреНрд░рддреНрдпреЗрдХ рдпреБрдЧ рдХреЗ рдмрд╛рдж рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдирдореВрдиреЗ рдХреЗ рд╣рд┐рд╕реНрд╕реЗ рдкрд░ рдореИрдЯреНрд░рд┐рдХреНрд╕ рдХреА рдЧрдгрдирд╛ рдХрд░реЗрдВрдЧреЗ, рдФрд░ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЗ рджреМрд░рд╛рди рдирд╣реАрдВред рдЗрд╕ рдкреНрд░рдХрд╛рд░, рджрдХреНрд╖рддрд╛ рдХрд╛ рдорд╛рдк рдЕрдзрд┐рдХ рд╕рдЯреАрдХ рд╣реЛрдЧрд╛, рдХреНрдпреЛрдВрдХрд┐ рдЧрдгрдирд╛ рдХреЗ рджреМрд░рд╛рди рдореЙрдбрд▓ рдирд╣реАрдВ рдмрджрд▓рддрд╛ рд╣реИред
рдЗрд╕рд▓рд┐рдП, рд╣рдо рдореИрдЯреНрд░рд┐рдХреНрд╕ рдХреЛ рдкрд░рд┐рднрд╛рд╖рд┐рдд рдХрд░рддреЗ рд╣реИрдВ:
from ignite.metrics import Loss, CategoricalAccuracy, Precision, Recall metrics = { 'avg_loss': Loss(criterion), 'avg_accuracy': CategoricalAccuracy(), 'avg_precision': Precision(average=True), 'avg_recall': Recall(average=True) }
рдЕрдЧрд▓рд╛, рд╣рдо ignite.engine.create_supervised_evaluator
рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ рдореЙрдбрд▓ рдХрд╛ рдореВрд▓реНрдпрд╛рдВрдХрди рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП рджреЛ рдЗрдВрдЬрди ignite.engine.create_supervised_evaluator
:
from ignite.engine import create_supervised_evaluator
рд╣рдо рдореЙрдбрд▓ рдХреЛ рдмрдЪрд╛рдиреЗ рдФрд░ рдЬрд▓реНрджреА рд╕реАрдЦрдиреЗ (рдЗрди рд╕рднреА рдХреЗ рдмрд╛рд░реЗ рдореЗрдВ рдиреАрдЪреЗ) рдХреЛ рд░реЛрдХрдиреЗ рдХреЗ рд▓рд┐рдП рдЙрдирдореЗрдВ рд╕реЗ рдПрдХ рдореЗрдВ рдЕрддрд┐рд░рд┐рдХреНрдд рдЗрд╡реЗрдВрдЯ рд╣реИрдВрдбрд▓рд░ рд╕рдВрд▓рдЧреНрди рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП рджреЛ рдЗрдВрдЬрди рдмрдирд╛ рд░рд╣реЗ рд╣реИрдВред
рдЖрдЗрдП рдЗрд╕ рдмрд╛рдд рдХрд╛ рднреА рдзреНрдпрд╛рди рд░рдЦреЗрдВ рдХрд┐ рдореЙрдбрд▓ рдХреЗ рдореВрд▓реНрдпрд╛рдВрдХрди рдХреЗ рд▓рд┐рдП рдЗрдВрдЬрди рдХреЛ рдХреИрд╕реЗ рдкрд░рд┐рднрд╛рд╖рд┐рдд рдХрд┐рдпрд╛ рдЧрдпрд╛ рд╣реИ, рдЕрд░реНрдерд╛рддреН рдЗрдирдкреБрдЯ рдлрд╝рдВрдХреНрд╢рди рдкреНрд░рдХреНрд░рд┐рдпрд╛_рдлрдВрдХреНрд╢рди рдХреЛ рдПрдХ рдмреИрдЪ рдХреЛ рдХреИрд╕реЗ рд╕рдВрд╕рд╛рдзрд┐рдд рдХрд┐рдпрд╛ рдЬрд╛рддрд╛ рд╣реИ:
def create_supervised_evaluator(model, metrics={}, device=None): if device: model.to(device) def _inference(engine, batch): model.eval() with torch.no_grad(): x, y = _prepare_batch(batch, device=device) y_pred = model(x) return y_pred, y engine = Engine(_inference) for name, metric in metrics.items(): metric.attach(engine, name) return engine
рд╣рдо рдЖрдЧреЗ рднреА рдЬрд╛рд░реА рд░рд╣реЗред рдЖрдЗрдП рд╣рдо рдпрд╛рджреГрдЪреНрдЫрд┐рдХ рд░реВрдк рд╕реЗ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдирдореВрдиреЗ рдХреЗ рдЙрд╕ рднрд╛рдЧ рдХрд╛ рдЪрдпрди рдХрд░реЗрдВ, рдЬрд┐рд╕ рдкрд░ рд╣рдо рдореИрдЯреНрд░рд┐рдХреНрд╕ рдХреА рдЧрдгрдирд╛ рдХрд░реЗрдВрдЧреЗ:
import numpy as np from torch.utils.data.dataset import Subset indices = np.arange(len(train_dataset)) random_indices = np.random.permutation(indices)[:len(val_dataset)] train_subset = Subset(train_dataset, indices=random_indices) train_eval_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True, pin_memory="cuda" in device)
рдЕрдЧрд▓рд╛, рдЖрдЗрдП рдирд┐рд░реНрдзрд╛рд░рд┐рдд рдХрд░реЗрдВ рдХрд┐ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдореЗрдВ рд╣рдо рдХрд┐рд╕ рдмрд┐рдВрджреБ рдкрд░ рдореИрдЯреНрд░рд┐рдХреНрд╕ рдХреА рдЧрдгрдирд╛ рд╢реБрд░реВ рдХрд░реЗрдВрдЧреЗ рдФрд░ рд╕реНрдХреНрд░реАрди рдкрд░ рдЖрдЙрдЯрдкреБрдЯ рдХрд░реЗрдВрдЧреЗ:
@trainer.on(Events.EPOCH_COMPLETED) def compute_and_display_offline_train_metrics(engine): epoch = engine.state.epoch print("Compute train metrics...") metrics = train_evaluator.run(train_eval_loader).metrics print("Training Results - Epoch: {} Average Loss: {:.4f} | Accuracy: {:.4f} | Precision: {:.4f} | Recall: {:.4f}" .format(engine.state.epoch, metrics['avg_loss'], metrics['avg_accuracy'], metrics['avg_precision'], metrics['avg_recall'])) @trainer.on(Events.EPOCH_COMPLETED) def compute_and_display_val_metrics(engine): epoch = engine.state.epoch print("Compute validation metrics...") metrics = val_evaluator.run(val_loader).metrics print("Validation Results - Epoch: {} Average Loss: {:.4f} | Accuracy: {:.4f} | Precision: {:.4f} | Recall: {:.4f}" .format(engine.state.epoch, metrics['avg_loss'], metrics['avg_accuracy'], metrics['avg_precision'], metrics['avg_recall']))
рддреБрдо рджреМрдбрд╝ рд╕рдХрддреЗ рд╣реЛ!
output = trainer.run(train_loader, max_epochs=1)
рд╣рдо рд╕реНрдХреНрд░реАрди рдкрд░ рдЖрддреЗ рд╣реИрдВ
Epoch[1] Iteration[50/322] Loss: 3.5112 Epoch[1] Iteration[100/322] Loss: 2.9840 Epoch[1] Iteration[150/322] Loss: 2.8807 Epoch[1] Iteration[200/322] Loss: 2.9285 Epoch[1] Iteration[250/322] Loss: 2.5026 Epoch[1] Iteration[300/322] Loss: 2.1944 Compute train metrics... Training Results - Epoch: 1 Average Loss: 2.1018 | Accuracy: 0.3699 | Precision: 0.3981 | Recall: 0.3686 Compute validation metrics... Validation Results - Epoch: 1 Average Loss: 2.0519 | Accuracy: 0.3850 | Precision: 0.3578 | Recall: 0.3845
рдкрд╣рд▓реЗ рд╕реЗ рдмреЗрд╣рддрд░!
рдХреБрдЫ рд╡рд┐рд╡рд░рдг
рдЖрдЗрдП рдкрд┐рдЫрд▓реЗ рдХреЛрдб рдХреЛ рдереЛрдбрд╝рд╛ рджреЗрдЦреЗрдВред рдкрд╛рдардХ рдиреЗ рдХреЛрдб рдХреА рдирд┐рдореНрдирд▓рд┐рдЦрд┐рдд рдкрдВрдХреНрддрд┐ рджреЗрдЦреА рд╣реЛрдЧреА:
metrics = train_evaluator.run(train_eval_loader).metrics
рдФрд░ рд╕рдВрднрд╡рдд: train_evaluator.run(train_eval_loader)
рд╕реЗ рдкреНрд░рд╛рдкреНрдд рдСрдмреНрдЬреЗрдХреНрдЯ рдХреЗ рдкреНрд░рдХрд╛рд░ рдХреЗ рдмрд╛рд░реЗ рдореЗрдВ рдПрдХ рд╕рд╡рд╛рд▓ рдерд╛, рдЬрд┐рд╕рдореЗрдВ metrics
рд╡рд┐рд╢реЗрд╖рддрд╛ рд╣реИред
рд╡рд╛рд╕реНрддрд╡ рдореЗрдВ, Engine
рд╡рд░реНрдЧ рдореЗрдВ рдПрдХ рд╕рдВрд░рдЪрдирд╛ рд╣реЛрддреА рд╣реИ рдЬрд┐рд╕реЗ state
(рдЯрд╛рдЗрдк State
) рдХрд╣рд╛ рдЬрд╛рддрд╛ рд╣реИ рддрд╛рдХрд┐ рдИрд╡реЗрдВрдЯ рд╣реИрдВрдбрд▓рд░ рдХреЗ рдмреАрдЪ рдбреЗрдЯрд╛ рд╕реНрдерд╛рдирд╛рдВрддрд░рд┐рдд рдХрд┐рдпрд╛ рдЬрд╛ рд╕рдХреЗред рдЗрд╕ state
рд╡рд┐рд╢реЗрд╖рддрд╛ рдореЗрдВ рд╡рд░реНрддрдорд╛рди рдпреБрдЧ, рдкреБрдирд░рд╛рд╡реГрддреНрддрд┐, рдпреБрдЧреЛрдВ рдХреА рд╕рдВрдЦреНрдпрд╛ рдЖрджрд┐ рдХреЗ рдмрд╛рд░реЗ рдореЗрдВ рдмреБрдирд┐рдпрд╛рджреА рдЬрд╛рдирдХрд╛рд░реА рд╣реИред рдЗрд╕рдХрд╛ рдЙрдкрдпреЛрдЧ рдореЗрдЯреНрд░рд┐рдХреНрд╕ рдХреА рдЧрдгрдирд╛ рдХреЗ рдкрд░рд┐рдгрд╛рдореЛрдВ рд╕рд╣рд┐рдд рдХрд┐рд╕реА рднреА рдЙрдкрдпреЛрдЧрдХрд░реНрддрд╛ рдбреЗрдЯрд╛ рдХреЛ рд╕реНрдерд╛рдирд╛рдВрддрд░рд┐рдд рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП рднреА рдХрд┐рдпрд╛ рдЬрд╛ рд╕рдХрддрд╛ рд╣реИред
state = train_evaluator.run(train_eval_loader) metrics = state.metrics
рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЗ рджреМрд░рд╛рди рдореИрдЯреНрд░рд┐рдХреНрд╕ рдХреА рдЧрдгрдирд╛
рдпрджрд┐ рдХрд╛рд░реНрдп рдореЗрдВ рдПрдХ рдмрдбрд╝рд╛ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рд╕реЗрдЯ рд╣реИ рдФрд░ рдкреНрд░рддреНрдпреЗрдХ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдпреБрдЧ рдХреЗ рдмрд╛рдж рдореИрдЯреНрд░рд┐рдХреНрд╕ рдХреА рдЧрдгрдирд╛ рдХрд░рдирд╛ рдорд╣рдВрдЧрд╛ рд╣реИ, рд▓реЗрдХрд┐рди рдлрд┐рд░ рднреА рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЗ рджреМрд░рд╛рди рдХреБрдЫ рдореЗрдЯреНрд░рд┐рдХреНрд╕ рдкрд░рд┐рд╡рд░реНрддрди рджреЗрдЦрдирд╛ рдЪрд╛рд╣рддреЗ рд╣реИрдВ, рддреЛ рдЖрдк рдмреЙрдХреНрд╕ рд╕реЗ рдирд┐рдореНрдирд▓рд┐рдЦрд┐рдд RunningAverage
рдЗрд╡реЗрдВрдЯ рд╣реИрдВрдбрд▓рд░ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░ рд╕рдХрддреЗ рд╣реИрдВред рдЙрджрд╛рд╣рд░рдг рдХреЗ рд▓рд┐рдП, рд╣рдо рд╡рд░реНрдЧреАрдХрд░рдг рдХреА рд╕рдЯреАрдХрддрд╛ рдХреА рдЧрдгрдирд╛ рдФрд░ рдкреНрд░рджрд░реНрд╢рди рдХрд░рдирд╛ рдЪрд╛рд╣рддреЗ рд╣реИрдВ:
acc_metric = RunningAverage(CategoryAccuracy(...), alpha=0.98) acc_metric.attach(trainer, 'running_avg_accuracy') @trainer.on(Events.ITERATION_COMPLETED) def log_running_avg_metrics(engine): print("running avg accuracy:", engine.state.metrics['running_avg_accuracy'])
RunningAverage
рдХрд╛рд░реНрдпрдХреНрд╖рдорддрд╛ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП, рдЖрдкрдХреЛ рд╕реНрд░реЛрддреЛрдВ рд╕реЗ рдЖрдЧ рд▓рдЧрд╛рдиреЗ рдХреА рдЖрд╡рд╢реНрдпрдХрддрд╛ рд╣реИ:
pip install git+https:
рд╕реАрдЦрдиреЗ рдХреА рджрд░ рдирд┐рд░реНрдзрд╛рд░рдг
рдкреНрд░рдЬреНрд╡рд▓рд┐рдд рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ рд╕реАрдЦрдиреЗ рдХреА рдЧрддрд┐ рдХреЛ рдмрджрд▓рдиреЗ рдХреЗ рдХрдИ рддрд░реАрдХреЗ рд╣реИрдВред рдЗрд╕рдХреЗ рдмрд╛рдж, рдкреНрд░рддреНрдпреЗрдХ рдпреБрдЧ рдХреА рд╢реБрд░реБрдЖрдд рдореЗрдВ lr_scheduler.step()
рдлрд╝рдВрдХреНрд╢рди рдХреЛ рдХреЙрд▓ рдХрд░рдХреЗ рд╕рдмрд╕реЗ рд╕рд░рд▓ рд╡рд┐рдзрд┐ рдкрд░ рд╡рд┐рдЪрд╛рд░ рдХрд░реЗрдВред
from torch.optim.lr_scheduler import ExponentialLR lr_scheduler = ExponentialLR(optimizer, gamma=0.8) @trainer.on(Events.EPOCH_STARTED) def update_lr_scheduler(engine): lr_scheduler.step()
рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЗ рджреМрд░рд╛рди рд╕рд░реНрд╡рд╢реНрд░реЗрд╖реНрда рдореЙрдбрд▓ рдФрд░ рдЕрдиреНрдп рдорд╛рдкрджрдВрдбреЛрдВ рдХреЛ рд╕рд╣реЗрдЬрдирд╛
рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЗ рджреМрд░рд╛рди, рдбрд┐рд╕реНрдХ рдкрд░ рд╕рд░реНрд╡рд╢реНрд░реЗрд╖реНрда рдореЙрдбрд▓ рдХреЗ рд╡рдЬрди рдХреЛ рд░рд┐рдХреЙрд░реНрдб рдХрд░рдирд╛ рдмрд╣реБрдд рдЕрдЪреНрдЫрд╛ рд╣реЛрдЧрд╛, рд╕рд╛рде рд╣реА рд╕рдордп-рд╕рдордп рдкрд░ рдореЙрдбрд▓ рдХреА рдКрдВрдЪрд╛рдИ, рдСрдкреНрдЯрд┐рдорд╛рдЗрдЬрд╝рд░ рдкреИрд░рд╛рдореАрдЯрд░ рдФрд░ рд╕реАрдЦрдиреЗ рдХреА рдЧрддрд┐ рдХреЛ рдмрджрд▓рдиреЗ рдХреЗ рд▓рд┐рдП рдорд╛рдкрджрдВрдбреЛрдВ рдХреЛ рдмрдЪрд╛рдиреЗ рдХреЗ рд▓рд┐рдПред рдЕрдВрддрд┐рдо рдмрдЪрд╛рдП рдЧрдП рд░рд╛рдЬреНрдп рд╕реЗ рд╕реАрдЦрдиреЗ рдХреЛ рдлрд┐рд░ рд╕реЗ рд╢реБрд░реВ рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП рдЙрддреНрддрд░рд╛рд░реНрджреНрдз рдЙрдкрдпреЛрдЧреА рд╣реЛ рд╕рдХрддрд╛ рд╣реИред
рдЗрдЧреНрдирд╛рдЗрдЯ рдореЗрдВ рдЗрд╕рдХреЗ рд▓рд┐рдП рдПрдХ рд╡рд┐рд╢реЗрд╖ ModelCheckpoint
рдХреНрд▓рд╛рд╕ рд╣реИред рддреЛ, рдЪрд▓рд┐рдП рдПрдХ ModelCheckpoint
рдЗрд╡реЗрдВрдЯ ModelCheckpoint
рдмрдирд╛рддреЗ рд╣реИрдВ рдФрд░ рдЯреЗрд╕реНрдЯ рд╕реЗрдЯ рдореЗрдВ рд╕рдЯреАрдХрддрд╛ рдХреЗ рдорд╛рдорд▓реЗ рдореЗрдВ рд╕рдмрд╕реЗ рдЕрдЪреНрдЫреЗ рдореЙрдбрд▓ рдХреЛ рдмрдЪрд╛рддреЗ рд╣реИрдВред рдЗрд╕ рдорд╛рдорд▓реЗ рдореЗрдВ, рд╣рдо рдПрдХ score_function
рдлрд╝рдВрдХреНрд╢рди рдХреЛ рдкрд░рд┐рднрд╛рд╖рд┐рдд рдХрд░рддреЗ рд╣реИрдВ рдЬреЛ рдЗрд╡реЗрдВрдЯ рд╣реИрдВрдбрд▓рд░ рдХреЛ рд╕рдЯреАрдХрддрд╛ рдореВрд▓реНрдп рджреЗрддрд╛ рд╣реИ рдФрд░ рдпрд╣ рддрдп рдХрд░рддрд╛ рд╣реИ рдХрд┐ рдореЙрдбрд▓ рдХреЛ рд╕рд╣реЗрдЬрдирд╛ рд╣реИ рдпрд╛ рдирд╣реАрдВ:
from ignite.handlers import ModelCheckpoint def score_function(engine): val_avg_accuracy = engine.state.metrics['avg_accuracy'] return val_avg_accuracy best_model_saver = ModelCheckpoint("best_models", filename_prefix="model", score_name="val_accuracy", score_function=score_function, n_saved=3, save_as_state_dict=True, create_dir=True)
рдЕрдм рдкреНрд░рддреНрдпреЗрдХ 1000 рдкреБрдирд░рд╛рд╡реГрддреНрддрд┐рдпреЛрдВ рдХреЛ рд╕реАрдЦрдиреЗ рдХреА рд╕реНрдерд┐рддрд┐ рдмрдирд╛рдП рд░рдЦрдиреЗ рдХреЗ рд▓рд┐рдП рдПрдХ рдФрд░ ModelCheckpoint
рдИрд╡реЗрдВрдЯ ModelCheckpoint
рдмрдирд╛рдПрдВ:
training_saver = ModelCheckpoint("checkpoint", filename_prefix="checkpoint", save_interval=1000, n_saved=1, save_as_state_dict=True, create_dir=True) to_save = {"model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler} trainer.add_event_handler(Events.ITERATION_COMPLETED, training_saver, to_save)
рддреЛ, рд▓рдЧрднрдЧ рд╕рдм рдХреБрдЫ рддреИрдпрд╛рд░ рд╣реИ, рдЕрдВрддрд┐рдо рддрддреНрд╡ рдЬреЛрдбрд╝реЗрдВ:
рдкреНрд░рд╛рд░рдВрднрд┐рдХ рд░реЛрдХ рдкреНрд░рд╢рд┐рдХреНрд╖рдг (рдкреНрд░рд╛рд░рдВрднрд┐рдХ рд░реЛрдХ)
рдЖрдЗрдП рдПрдХ рдФрд░ рдИрд╡реЗрдВрдЯ рд╣реИрдВрдбрд▓рд░ рдЬреЛрдбрд╝реЗрдВ рдЬреЛ 10 рдпреБрдЧреЛрдВ рдореЗрдВ рдореЙрдбрд▓ рдХреА рдЧреБрдгрд╡рддреНрддрд╛ рдореЗрдВ рд╕реБрдзрд╛рд░ рдирд╣реАрдВ рд╣реЛрдиреЗ рдкрд░ рд╕реАрдЦрдирд╛ рдмрдВрдж рдХрд░ рджреЗрдЧрд╛ред рд╣рдо рдлрд┐рд░ рд╕реЗ рд╕реНрдХреЛрд░_рдлрдВрдХреНрд╢рди score_function
рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ рдореЙрдбрд▓ рдХреА рдЧреБрдгрд╡рддреНрддрд╛ рдХрд╛ рдореВрд▓реНрдпрд╛рдВрдХрди рдХрд░реЗрдВрдЧреЗред
from ignite.handlers import EarlyStopping early_stopping = EarlyStopping(patience=10, score_function=score_function, trainer=trainer) val_evaluator.add_event_handler(Events.EPOCH_COMPLETED, early_stopping)
рдкреНрд░рд╢рд┐рдХреНрд╖рдг рд╢реБрд░реВ рдХрд░реЗрдВ
рдкреНрд░рд╢рд┐рдХреНрд╖рдг рд╢реБрд░реВ рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП, рд╣рдорд╛рд░реЗ рд▓рд┐рдП run()
рд╡рд┐рдзрд┐ рдХреЛ рдХреЙрд▓ рдХрд░рдирд╛ рдкрд░реНрдпрд╛рдкреНрдд рд╣реИред рд╣рдо 10 рдпреБрдЧреЛрдВ рдХреЗ рд▓рд┐рдП рдореЙрдбрд▓ рдХреЛ рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдХрд░реЗрдВрдЧреЗ:
max_epochs = 10 output = trainer.run(train_loader, max_epochs=max_epochs)
рд╕реНрдХреНрд░реАрди рдЖрдЙрдЯрдкреБрдЯ Learning rate: 0.01 Epoch[1] Iteration[50/322] Loss: 2.7984 Epoch[1] Iteration[100/322] Loss: 1.9736 Epoch[1] Iteration[150/322] Loss: 4.3419 Epoch[1] Iteration[200/322] Loss: 2.0261 Epoch[1] Iteration[250/322] Loss: 2.1724 Epoch[1] Iteration[300/322] Loss: 2.1599 Compute train metrics... Training Results - Epoch: 1 Average Loss: 1.5363 | Accuracy: 0.5177 | Precision: 0.5477 | Recall: 0.5178 Compute validation metrics... Validation Results - Epoch: 1 Average Loss: 1.5116 | Accuracy: 0.5139 | Precision: 0.5400 | Recall: 0.5140 Learning rate: 0.008 Epoch[2] Iteration[50/322] Loss: 1.4076 Epoch[2] Iteration[100/322] Loss: 1.4892 Epoch[2] Iteration[150/322] Loss: 1.2485 Epoch[2] Iteration[200/322] Loss: 1.6511 Epoch[2] Iteration[250/322] Loss: 3.3376 Epoch[2] Iteration[300/322] Loss: 1.3299 Compute train metrics... Training Results - Epoch: 2 Average Loss: 3.2686 | Accuracy: 0.1977 | Precision: 0.1792 | Recall: 0.1942 Compute validation metrics... Validation Results - Epoch: 2 Average Loss: 3.2772 | Accuracy: 0.1962 | Precision: 0.1628 | Recall: 0.1918 Learning rate: 0.006400000000000001 Epoch[3] Iteration[50/322] Loss: 0.9016 Epoch[3] Iteration[100/322] Loss: 1.2006 Epoch[3] Iteration[150/322] Loss: 0.8892 Epoch[3] Iteration[200/322] Loss: 0.8141 Epoch[3] Iteration[250/322] Loss: 1.4005 Epoch[3] Iteration[300/322] Loss: 0.8888 Compute train metrics... Training Results - Epoch: 3 Average Loss: 0.7368 | Accuracy: 0.7554 | Precision: 0.7818 | Recall: 0.7554 Compute validation metrics... Validation Results - Epoch: 3 Average Loss: 0.7177 | Accuracy: 0.7623 | Precision: 0.7863 | Recall: 0.7611 Learning rate: 0.005120000000000001 Epoch[4] Iteration[50/322] Loss: 0.8490 Epoch[4] Iteration[100/322] Loss: 0.8493 Epoch[4] Iteration[150/322] Loss: 0.8100 Epoch[4] Iteration[200/322] Loss: 0.9165 Epoch[4] Iteration[250/322] Loss: 0.9370 Epoch[4] Iteration[300/322] Loss: 0.6548 Compute train metrics... Training Results - Epoch: 4 Average Loss: 0.7047 | Accuracy: 0.7713 | Precision: 0.8040 | Recall: 0.7728 Compute validation metrics... Validation Results - Epoch: 4 Average Loss: 0.6737 | Accuracy: 0.7778 | Precision: 0.7955 | Recall: 0.7806 Learning rate: 0.004096000000000001 Epoch[5] Iteration[50/322] Loss: 0.6965 Epoch[5] Iteration[100/322] Loss: 0.6196 Epoch[5] Iteration[150/322] Loss: 0.6194 Epoch[5] Iteration[200/322] Loss: 0.3986 Epoch[5] Iteration[250/322] Loss: 0.6032 Epoch[5] Iteration[300/322] Loss: 0.7152 Compute train metrics... Training Results - Epoch: 5 Average Loss: 0.5049 | Accuracy: 0.8282 | Precision: 0.8393 | Recall: 0.8314 Compute validation metrics... Validation Results - Epoch: 5 Average Loss: 0.5084 | Accuracy: 0.8304 | Precision: 0.8386 | Recall: 0.8328 Learning rate: 0.0032768000000000007 Epoch[6] Iteration[50/322] Loss: 0.4433 Epoch[6] Iteration[100/322] Loss: 0.4764 Epoch[6] Iteration[150/322] Loss: 0.5578 Epoch[6] Iteration[200/322] Loss: 0.3684 Epoch[6] Iteration[250/322] Loss: 0.4847 Epoch[6] Iteration[300/322] Loss: 0.3811 Compute train metrics... Training Results - Epoch: 6 Average Loss: 0.4383 | Accuracy: 0.8474 | Precision: 0.8618 | Recall: 0.8495 Compute validation metrics... Validation Results - Epoch: 6 Average Loss: 0.4419 | Accuracy: 0.8446 | Precision: 0.8532 | Recall: 0.8442 Learning rate: 0.002621440000000001 Epoch[7] Iteration[50/322] Loss: 0.4447 Epoch[7] Iteration[100/322] Loss: 0.4602 Epoch[7] Iteration[150/322] Loss: 0.5345 Epoch[7] Iteration[200/322] Loss: 0.3973 Epoch[7] Iteration[250/322] Loss: 0.5023 Epoch[7] Iteration[300/322] Loss: 0.5303 Compute train metrics... Training Results - Epoch: 7 Average Loss: 0.4305 | Accuracy: 0.8579 | Precision: 0.8691 | Recall: 0.8596 Compute validation metrics... Validation Results - Epoch: 7 Average Loss: 0.4262 | Accuracy: 0.8590 | Precision: 0.8685 | Recall: 0.8606 Learning rate: 0.002097152000000001 Epoch[8] Iteration[50/322] Loss: 0.4867 Epoch[8] Iteration[100/322] Loss: 0.3090 Epoch[8] Iteration[150/322] Loss: 0.3721 Epoch[8] Iteration[200/322] Loss: 0.4559 Epoch[8] Iteration[250/322] Loss: 0.3958 Epoch[8] Iteration[300/322] Loss: 0.4222 Compute train metrics... Training Results - Epoch: 8 Average Loss: 0.3432 | Accuracy: 0.8818 | Precision: 0.8895 | Recall: 0.8817 Compute validation metrics... Validation Results - Epoch: 8 Average Loss: 0.3644 | Accuracy: 0.8713 | Precision: 0.8784 | Recall: 0.8707 Learning rate: 0.001677721600000001 Epoch[9] Iteration[50/322] Loss: 0.3557 Epoch[9] Iteration[100/322] Loss: 0.3692 Epoch[9] Iteration[150/322] Loss: 0.3510 Epoch[9] Iteration[200/322] Loss: 0.3446 Epoch[9] Iteration[250/322] Loss: 0.3966 Epoch[9] Iteration[300/322] Loss: 0.3451 Compute train metrics... Training Results - Epoch: 9 Average Loss: 0.3315 | Accuracy: 0.8954 | Precision: 0.9001 | Recall: 0.8982 Compute validation metrics... Validation Results - Epoch: 9 Average Loss: 0.3559 | Accuracy: 0.8818 | Precision: 0.8876 | Recall: 0.8847 Learning rate: 0.0013421772800000006 Epoch[10] Iteration[50/322] Loss: 0.3340 Epoch[10] Iteration[100/322] Loss: 0.3370 Epoch[10] Iteration[150/322] Loss: 0.3694 Epoch[10] Iteration[200/322] Loss: 0.3409 Epoch[10] Iteration[250/322] Loss: 0.4420 Epoch[10] Iteration[300/322] Loss: 0.2770 Compute train metrics... Training Results - Epoch: 10 Average Loss: 0.3246 | Accuracy: 0.8921 | Precision: 0.8988 | Recall: 0.8925 Compute validation metrics... Validation Results - Epoch: 10 Average Loss: 0.3536 | Accuracy: 0.8731 | Precision: 0.8785 | Recall: 0.8722
рдЕрдм рдбрд┐рд╕реНрдХ рдкрд░ рд╕рд╣реЗрдЬреЗ рдЧрдП рдореЙрдбрд▓ рдФрд░ рдорд╛рдкрджрдВрдбреЛрдВ рдХреА рдЬрд╛рдБрдЪ рдХрд░реЗрдВ:
ls best_models/ model_best_model_10_val_accuracy=0.8730994.pth model_best_model_8_val_accuracy=0.8712978.pth model_best_model_9_val_accuracy=0.8818188.pth
рдФрд░
ls checkpoint/ checkpoint_lr_scheduler_3000.pth checkpoint_optimizer_3000.pth checkpoint_model_3000.pth
рдПрдХ рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдореЙрдбрд▓ рджреНрд╡рд╛рд░рд╛ рднрд╡рд┐рд╖реНрдпрд╡рд╛рдгрд┐рдпреЛрдВ
рд╕рдмрд╕реЗ рдкрд╣рд▓реЗ, рдПрдХ рдкрд░реАрдХреНрд╖рдг рдбреЗрдЯрд╛ рд▓реЛрдбрд░ рдмрдирд╛рдПрдВ (рдЙрджрд╛рд╣рд░рдг рдХреЗ рд▓рд┐рдП, рдПрдХ рд╕рддреНрдпрд╛рдкрди рдирдореВрдирд╛ рд▓реЗрдВ) рддрд╛рдХрд┐ рдбреЗрдЯрд╛ рдмреИрдЪ рдореЗрдВ рдЪрд┐рддреНрд░ рдФрд░ рдЙрдирдХреЗ рд╕реВрдЪрдХрд╛рдВрдХ рд╢рд╛рдорд┐рд▓ рд╣реЛрдВ:
class TestDataset(Dataset): def __init__(self, ds): self.ds = ds def __len__(self): return len(self.ds) def __getitem__(self, index): return self.ds[index][0], index test_dataset = TestDataset(val_dataset) test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=False, pin_memory="cuda" in device)
рдЗрдЧреНрдирд╛рдЗрдЯ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рддреЗ рд╣реБрдП , рд╣рдо рдкрд░реАрдХреНрд╖рдг рдбреЗрдЯрд╛ рдХреЗ рд▓рд┐рдП рдПрдХ рдирдпрд╛ рдкреВрд░реНрд╡рд╛рдиреБрдорд╛рди рдЗрдВрдЬрди рдмрдирд╛рдПрдВрдЧреЗред рдРрд╕рд╛ рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП, рд╣рдо рдлрд╝рдВрдХреНрд╢рди inference_update
рдХреЛ рдкрд░рд┐рднрд╛рд╖рд┐рдд рдХрд░рддреЗ inference_update
, рдЬреЛ рдЫрд╡рд┐ рдХреА рднрд╡рд┐рд╖реНрдпрд╡рд╛рдгреА рдФрд░ рд╕реВрдЪрдХрд╛рдВрдХ рдХрд╛ рдкрд░рд┐рдгрд╛рдо рджреЗрддрд╛ рд╣реИред рд╕рдЯреАрдХрддрд╛ рдмрдврд╝рд╛рдиреЗ рдХреЗ рд▓рд┐рдП, рд╣рдо рдЬрд╛рдиреЗ-рдорд╛рдиреЗ рдЯреНрд░рд┐рдХ "рдЯреЗрд╕реНрдЯ рдЯрд╛рдЗрдо рдПрдирдЧрдарди" (TTA) рдХрд╛ рднреА рдЙрдкрдпреЛрдЧ рдХрд░реЗрдВрдЧреЗред
import torch.nn.functional as F from ignite._utils import convert_tensor def _prepare_batch(batch): x, index = batch x = convert_tensor(x, device=device) return x, index def inference_update(engine, batch): x, indices = _prepare_batch(batch) y_pred = model(x) y_pred = F.softmax(y_pred, dim=1) return {"y_pred": convert_tensor(y_pred, device='cpu'), "indices": indices} model.eval() inferencer = Engine(inference_update)
рдЕрдЧрд▓рд╛, рдИрд╡реЗрдВрдЯ рд╣реИрдВрдбрд▓рд░ рдмрдирд╛рдПрдВ рдЬреЛ рднрд╡рд┐рд╖реНрдпрд╡рд╛рдгрд┐рдпреЛрдВ рдХреЗ рдЪрд░рдг рдХреЗ рдмрд╛рд░реЗ рдореЗрдВ рд╕реВрдЪрд┐рдд рдХрд░реЗрдВрдЧреЗ рдФрд░ рднрд╡рд┐рд╖реНрдпрд╡рд╛рдгрд┐рдпреЛрдВ рдХреЛ рдПрдХ рд╕рдорд░реНрдкрд┐рдд рд╕рд░рдгреА рдореЗрдВ рд╕рд╣реЗрдЬреЗрдВрдЧреЗ:
@inferencer.on(Events.EPOCH_COMPLETED) def log_tta(engine): print("TTA {} / {}".format(engine.state.epoch, n_tta)) n_tta = 3 num_classes = 81 n_samples = len(val_dataset) # y_probas_tta = np.zeros((n_samples, num_classes, n_tta), dtype=np.float32) @inferencer.on(Events.ITERATION_COMPLETED) def save_results(engine): output = engine.state.output tta_index = engine.state.epoch - 1 start_index = ((engine.state.iteration - 1) % len(test_loader)) * batch_size end_index = min(start_index + batch_size, n_samples) batch_y_probas = output['y_pred'].detach().numpy() y_probas_tta[start_index:end_index, :, tta_index] = batch_y_probas
рдкреНрд░рдХреНрд░рд┐рдпрд╛ рд╢реБрд░реВ рдХрд░рдиреЗ рд╕реЗ рдкрд╣рд▓реЗ, рдЖрдЗрдП рд╕рдмрд╕реЗ рдЕрдЪреНрдЫрд╛ рдореЙрдбрд▓ рдбрд╛рдЙрдирд▓реЛрдб рдХрд░реЗрдВ:
model = squeezenet1_1(pretrained=False, num_classes=64) model.classifier[-1] = nn.AdaptiveAvgPool2d(1) model = model.to(device) model_state_dict = torch.load("best_models/model_best_model_10_val_accuracy=0.8730994.pth") model.load_state_dict(model_state_dict)
рд╣рдо рд▓реЙрдиреНрдЪ рдХрд░рддреЗ рд╣реИрдВ:
inferencer.run(test_loader, max_epochs=n_tta) > TTA 1 / 3 > TTA 2 / 3 > TTA 3 / 3
рдЕрдЧрд▓рд╛, рдорд╛рдирдХ рддрд░реАрдХреЗ рд╕реЗ, рд╣рдо TTA рднрд╡рд┐рд╖реНрдпрд╡рд╛рдгрд┐рдпреЛрдВ рдХрд╛ рдФрд╕рдд рд▓реЗрддреЗ рд╣реИрдВ рдФрд░ рдЙрдЪреНрдЪрддрдо рд╕рдВрднрд╛рд╡рдирд╛ рд╡рд╛рд▓реЗ рд╡рд░реНрдЧ рд╕реВрдЪрдХрд╛рдВрдХ рдХреА рдЧрдгрдирд╛ рдХрд░рддреЗ рд╣реИрдВ:
y_probas = np.mean(y_probas_tta, axis=-1) y_preds = np.argmax(y_probas, axis=-1)
рдФрд░ рдЕрдм рд╣рдо рднрд╡рд┐рд╖реНрдпрд╡рд╛рдгрд┐рдпреЛрдВ рдХреЗ рдЕрдиреБрд╕рд╛рд░ рдПрдХ рдмрд╛рд░ рдлрд┐рд░ рд╕реЗ рдореЙрдбрд▓ рдХреА рд╕рдЯреАрдХрддрд╛ рдХреА рдЧрдгрдирд╛ рдХрд░ рд╕рдХрддреЗ рд╣реИрдВ:
from sklearn.metrics import accuracy_score y_test_true = [y for _, y in val_dataset] accuracy_score(y_test_true, y_preds) > 0.9310369676443035
, , . , , , ignite .
.
github
- fast neural transfer
- reinforcement learning
- dcgan
рдирд┐рд╖реНрдХрд░реНрд╖
, ignite Facebook (. ). 0.1.0, API (Engine, State, Events, Metric, ...) . , , , pull request- github .
рдЖрдкрдХрд╛ рдзреНрдпрд╛рди рдХреЗ рд▓рд┐рдП рдзрдиреНрдпрд╡рд╛рдж!