рддрдВрддреНрд░рд┐рдХрд╛ рд╕рд╛рдзрд╛рд░рдг рд╡рд┐рднреЗрджрдХ рд╕рдореАрдХрд░рдг
рдЕрдВрддрд░ рд╕рдореАрдХрд░рдгреЛрдВ рджреНрд╡рд╛рд░рд╛ рдкреНрд░рдХреНрд░рд┐рдпрд╛рдУрдВ рдХрд╛ рдПрдХ рдорд╣рддреНрд╡рдкреВрд░реНрдг рдЕрдиреБрдкрд╛рдд рд╡рд░реНрдгрд┐рдд рд╣реИ, рдпрд╣ рд╕рдордп рдХреЗ рд╕рд╛рде рднреМрддрд┐рдХ рдкреНрд░рдгрд╛рд▓реА рдХрд╛ рд╡рд┐рдХрд╛рд╕, рд░реЛрдЧреА рдХреА рдЪрд┐рдХрд┐рддреНрд╕рд╛ рд╕реНрдерд┐рддрд┐, рд╢реЗрдпрд░ рдмрд╛рдЬрд╛рд░ рдХреА рдореВрд▓рднреВрдд рд╡рд┐рд╢реЗрд╖рддрд╛рдУрдВ рдЖрджрд┐ рд╣реЛ рд╕рдХрддрд╛ рд╣реИред рдЗрд╕ рддрд░рд╣ рдХреА рдкреНрд░рдХреНрд░рд┐рдпрд╛рдУрдВ рдкрд░ рдбреЗрдЯрд╛ рдкреНрд░рдХреГрддрд┐ рдореЗрдВ рд▓рдЧрд╛рддрд╛рд░ рдФрд░ рдирд┐рд░рдВрддрд░ рд╣реИрдВ, рдЗрд╕ рдЕрд░реНрде рдореЗрдВ рдХрд┐ рдЕрд╡рд▓реЛрдХрди рдХреЗрд╡рд▓ рдХреБрдЫ рдкреНрд░рдХрд╛рд░ рдХреЗ рдирд┐рд░рдВрддрд░ рдмрджрд▓рддреЗ рд░рд╛рдЬреНрдп рдХреА рдЕрднрд┐рд╡реНрдпрдХреНрддрд┐рдпрд╛рдВ рд╣реИрдВред
рдПрдХ рдЕрдиреНрдп рдкреНрд░рдХрд╛рд░ рдХрд╛ рд╕реАрд░рд┐рдпрд▓ рдбреЗрдЯрд╛ рднреА рд╣реИ, рдпрд╣ рдЕрд╕рддрдд рдбреЗрдЯрд╛ рд╣реИ, рдЙрджрд╛рд╣рд░рдг рдХреЗ рд▓рд┐рдП, рдПрдирдПрд▓рдкреА рдЯрд╛рд╕реНрдХ рдбреЗрдЯрд╛ред рдЗрд╕ рддрд░рд╣ рдХреЗ рдбреЗрдЯрд╛ рдореЗрдВ рд╕реНрдерд┐рддрд┐ рдЕрд▓рдЧ-рдЕрд▓рдЧ рд╣реЛрддреА рд╣реИ: рдПрдХ рдЪрд░рд┐рддреНрд░ рдпрд╛ рд╢рдмреНрдж рд╕реЗ рджреВрд╕рд░реЗ рдореЗрдВред
рдЕрдм рджреЛрдиреЛрдВ рдкреНрд░рдХрд╛рд░ рдХреЗ рдРрд╕реЗ рд╕реАрд░рд┐рдпрд▓ рдбреЗрдЯрд╛ рдХреЛ рдЖрдорддреМрд░ рдкрд░ рдкреБрдирд░рд╛рд╡рд░реНрддреА рдиреЗрдЯрд╡рд░реНрдХ рджреНрд╡рд╛рд░рд╛ рд╕рдВрд╕рд╛рдзрд┐рдд рдХрд┐рдпрд╛ рдЬрд╛рддрд╛ рд╣реИ, рд╣рд╛рд▓рд╛рдВрдХрд┐ рд╡реЗ рдкреНрд░рдХреГрддрд┐ рдореЗрдВ рднрд┐рдиреНрди рд╣реЛрддреЗ рд╣реИрдВ рдФрд░ рд╡рд┐рднрд┐рдиреНрди рджреГрд╖реНрдЯрд┐рдХреЛрдгреЛрдВ рдХреА рдЖрд╡рд╢реНрдпрдХрддрд╛ рд╣реЛрддреА рд╣реИред
рдкрд┐рдЫрд▓реЗ
рдПрдирдЖрдИрдкреАрдПрд╕ рд╕рдореНрдореЗрд▓рди рдореЗрдВ рдПрдХ рдмрд╣реБрдд рд╣реА рджрд┐рд▓рдЪрд╕реНрдк рд▓реЗрдЦ рдкреНрд░рд╕реНрддреБрдд рдХрд┐рдпрд╛ рдЧрдпрд╛ рдерд╛, рдЬреЛ рдЗрд╕ рд╕рдорд╕реНрдпрд╛ рдХреЛ рд╣рд▓ рдХрд░рдиреЗ рдореЗрдВ рдорджрдж рдХрд░ рд╕рдХрддрд╛ рд╣реИред рд▓реЗрдЦрдХ рдПрдХ рджреГрд╖реНрдЯрд┐рдХреЛрдг рдХрд╛ рдкреНрд░рд╕реНрддрд╛рд╡ рдХрд░рддреЗ рд╣реИрдВ рдЬрд┐рд╕реЗ рдЙрдиреНрд╣реЛрдВрдиреЗ
рдиреНрдпреВрд░рд▓ рдУрдбреАрдИ рдХрд╣рд╛ рд╣реИред
рдпрд╣рд╛рдБ рдореИрдВрдиреЗ рдЗрд╕ рд▓реЗрдЦ рдХреЗ рдкрд░рд┐рдгрд╛рдореЛрдВ рдХреЛ рдкреБрди: рдкреНрд░рд╕реНрддреБрдд рдХрд░рдиреЗ рдФрд░ рд╕рдВрдХреНрд╖реЗрдк рдореЗрдВ рдкреНрд░рд╕реНрддреБрдд рдХрд░рдиреЗ рдХреА рдХреЛрд╢рд┐рд╢ рдХреА рддрд╛рдХрд┐ рдЙрдирдХреЗ рд╡рд┐рдЪрд╛рд░ рд╕реЗ рдкрд░рд┐рдЪрд┐рдд рд╣реЛрдирд╛ рдЖрд╕рд╛рди рд╣реЛ рд╕рдХреЗред рдпрд╣ рдореБрдЭреЗ рд▓рдЧрддрд╛ рд╣реИ рдХрд┐ рдЗрд╕ рдирдП рдЖрд░реНрдХрд┐рдЯреЗрдХреНрдЪрд░ рдХреЛ рдЕрдЪреНрдЫреА рддрд░рд╣ рд╕реЗ рдбреЗрдЯрд╛ рд╡реИрдЬреНрдЮрд╛рдирд┐рдХ рдХреЗ рдорд╛рдирдХ рдЙрдкрдХрд░рдгреЛрдВ рдореЗрдВ рдПрдХ рдЬрдЧрд╣ рдорд┐рд▓ рд╕рдХрддреА рд╣реИ рдФрд░ рд╕рд╛рде рд╣реА рд╕рд╛рде рдЖрд╡рд░реНрддреА рдиреЗрдЯрд╡рд░реНрдХ рднреАред

рдЪрд┐рддреНрд░рд╛ 1: рдирд┐рд░рдВрддрд░ рдврд╛рд▓ backpropagation рд╕рдордп рдореЗрдВ рд╕рдВрд╡рд░реНрдзрд┐рдд рдЕрдВрддрд░ рд╕рдореАрдХрд░рдг рдХреЛ рд╣рд▓ рдХрд░рдиреЗ рдХреА рдЖрд╡рд╢реНрдпрдХрддрд╛ рд╣реИред
рддреАрд░ рдЯрд┐рдкреНрдкрдгрд┐рдпреЛрдВ рд╕реЗ рдЧреНрд░реЗрдбрд┐рдПрдВрдЯреНрд╕ рджреНрд╡рд╛рд░рд╛ рдкрд┐рдЫрдбрд╝реЗ рдкреНрд░рдЪрд╛рд░рд┐рдд рдЧреНрд░реЗрдбрд┐рдПрдВрдЯреНрд╕ рдХреЗ рд╕рдорд╛рдпреЛрдЬрди рдХрд╛ рдкреНрд░рддрд┐рдирд┐рдзрд┐рддреНрд╡ рдХрд░рддреЗ рд╣реИрдВред
рдореВрд▓ рд▓реЗрдЦ рд╕реЗ рдЪрд┐рддреНрд░рдгред
рд╕рдорд╕реНрдпрд╛ рдХрдерди
рдПрдХ рдкреНрд░рдХреНрд░рд┐рдпрд╛ рд╣реИ рдЬреЛ рдХреБрдЫ рдЕрдЬреНрдЮрд╛рдд ODE рдХрд╛ рдкрд╛рд▓рди рдХрд░рддреА рд╣реИ рдФрд░ рдкреНрд░рдХреНрд░рд┐рдпрд╛ рдХреЗ рдкреНрд░рдХреНрд╖реЗрдкрд╡рдХреНрд░ рдХреЗ рд╕рд╛рде рдХрдИ (рд╢реЛрд░) рдЕрд╡рд▓реЛрдХрди рдХрд░рддреЗ рд╣реИрдВ
рдХреИрд╕реЗ рдПрдХ рдЕрдиреБрдорд╛рди рд▓рдЧрд╛рдиреЗ рдХреЗ рд▓рд┐рдП

рд╕реНрдкреАрдХрд░ рдХреЗ рдХрд╛рд░реНрдп

?
рд╕рдмрд╕реЗ рдкрд╣рд▓реЗ, рдПрдХ рд╕рд░рд▓ рдХрд╛рд░реНрдп рдкрд░ рд╡рд┐рдЪрд╛рд░ рдХрд░реЗрдВ: рдХреЗрд╡рд▓ 2 рдЕрд╡рд▓реЛрдХрди рд╣реИрдВ, рд╢реБрд░реБрдЖрдд рдореЗрдВ рдФрд░ рдкреНрд░рдХреНрд╖реЗрдкрд╡рдХреНрд░ рдХреЗ рдЕрдВрдд рдореЗрдВ,

ред
рд╕рд┐рд╕реНрдЯрдо рдХрд╛ рд╡рд┐рдХрд╛рд╕ рд░рд╛рдЬреНрдп рд╕реЗ рд╢реБрд░реВ рд╣реЛрддрд╛ рд╣реИ

рд╕рдордп рдкрд░

ODE рд╕рд┐рд╕реНрдЯрдо рдХреЗ рд╡рд┐рдХрд╛рд╕ рдХреЗ рдХрд┐рд╕реА рднреА рддрд░реАрдХреЗ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ рдХреБрдЫ рдкреИрд░рд╛рдореАрдЯрд░ рд╡рд╛рд▓реЗ рдбрд╛рдпрдирд╛рдорд┐рдХреНрд╕ рдлрд╝рдВрдХреНрд╢рди рдХреЗ рд╕рд╛рдеред рд╕рд┐рд╕реНрдЯрдо рдХреЗ рдмрд╛рдж рдПрдХ рдирдП рд░рд╛рдЬреНрдп рдореЗрдВ рд╣реИ

, рдпрд╣ рд░рд╛рдЬреНрдп рдХреЗ рд╕рд╛рде рддреБрд▓рдирд╛ рдХреА рдЬрд╛рддреА рд╣реИ

рдФрд░ рдЙрдирдХреЗ рдмреАрдЪ рдХрд╛ рдЕрдВрддрд░ рдорд╛рдирдХреЛрдВ рдХреЛ рдЕрд▓рдЧ рдХрд░рдХреЗ рдХрдо рд╕реЗ рдХрдо рдХрд┐рдпрд╛ рдЬрд╛рддрд╛ рд╣реИ

рдЧрддрд┐рд╢реАрд▓рддрд╛ рдХрд╛рд░реНрдп рдХрд░рддрд╛ рд╣реИред
рдпрд╛, рдФрдкрдЪрд╛рд░рд┐рдХ рд░реВрдк рд╕реЗ, рдиреБрдХрд╕рд╛рди рдлрд╝рдВрдХреНрд╢рди рдХреЛ рдХрдо рдХрд░рдиреЗ рдкрд░ рд╡рд┐рдЪрд╛рд░ рдХрд░реЗрдВ

:
рдХреЛ рдХрдо рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП

, рдЖрдкрдХреЛ рдЗрд╕рдХреЗ рд╕рднреА рдорд╛рдкрджрдВрдбреЛрдВ рдХреЗ рд▓рд┐рдП рдЧреНрд░реЗрдбрд┐рдПрдВрдЯ рдХреА рдЧрдгрдирд╛ рдХрд░рдиреЗ рдХреА рдЖрд╡рд╢реНрдпрдХрддрд╛ рд╣реИ:

ред рдРрд╕рд╛ рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП, рдЖрдкрдХреЛ рдкрд╣рд▓реЗ рдпрд╣ рдирд┐рд░реНрдзрд╛рд░рд┐рдд рдХрд░рдиреЗ рдХреА рдЖрд╡рд╢реНрдпрдХрддрд╛ рд╣реИ рдХрд┐ рдХреИрд╕реЗ

рд╣рд░ рдкрд▓ рд░рд╛рдЬреНрдп рдкрд░ рдирд┐рд░реНрднрд░ рдХрд░рддрд╛ рд╣реИ

:

рдПрдХ
рд╕рд╣рд╛рдпрдХ рд░рд╛рдЬреНрдп рдХрд╣рд╛ рдЬрд╛рддрд╛ рд╣реИ, рдЗрд╕рдХреА рдЧрддрд┐рд╢реАрд▓рддрд╛ рдПрдХ рдЕрдиреНрдп рдЕрдВрддрд░ рд╕рдореАрдХрд░рдг рджреНрд╡рд╛рд░рд╛ рджреА рдЧрдИ рд╣реИ, рдЬрд┐рд╕реЗ рдПрдХ рдЬрдЯрд┐рд▓ рдлрд╝рдВрдХреНрд╢рди (
рд╢реНрд░реГрдВрдЦрд▓рд╛ рдирд┐рдпрдо ) рдХреЗ рднреЗрджрднрд╛рд╡ рдХреЗ рдПрдХ рдирд┐рд░рдВрддрд░ рдПрдирд╛рд▓реЙрдЧ рдорд╛рдирд╛ рдЬрд╛ рд╕рдХрддрд╛ рд╣реИ:
рдЗрд╕ рд╕реВрддреНрд░ рдХрд╛ рдЖрдЙрдЯрдкреБрдЯ рдореВрд▓ рд▓реЗрдЦ рдХреЗ рдкрд░рд┐рд╢рд┐рд╖реНрдЯ рдореЗрдВ рдкрд╛рдпрд╛ рдЬрд╛ рд╕рдХрддрд╛ рд╣реИред
рдЗрд╕ рдЖрд▓реЗрдЦ рдореЗрдВ рд╡реИрдХреНрдЯрд░ рдХреЛ рд▓реЛрдЕрд░рдХреЗрд╕ рд╡реИрдХреНрдЯрд░ рдорд╛рдирд╛ рдЬрд╛рдирд╛ рдЪрд╛рд╣рд┐рдП, рд╣рд╛рд▓рд╛рдВрдХрд┐ рдореВрд▓ рд▓реЗрдЦ рдПрдХ рдкрдВрдХреНрддрд┐ рдФрд░ рд╕реНрддрдВрдн рдкреНрд░рддрд┐рдирд┐рдзрд┐рддреНрд╡ рджреЛрдиреЛрдВ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рддрд╛ рд╣реИредрд╕рдордп рдореЗрдВ рдЕрдВрддрд░ (4) рдХреЛ рд╣рд▓ рдХрд░рддреЗ рд╣реБрдП, рд╣рдо рдкреНрд░рд╛рд░рдВрднрд┐рдХ рдЕрд╡рд╕реНрдерд╛ рдкрд░ рдирд┐рд░реНрднрд░рддрд╛ рдкреНрд░рд╛рдкреНрдд рдХрд░рддреЗ рд╣реИрдВ

:
рдХреЗ рд╕рдВрдмрдВрдз рдореЗрдВ рдврд╛рд▓ рдХреА рдЧрдгрдирд╛ рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП

рдФрд░

, рдЖрдк рдмрд╕ рдЙрдиреНрд╣реЗрдВ рд░рд╛рдЬреНрдп рдХрд╛ рд╣рд┐рд╕реНрд╕рд╛ рдорд╛рди рд╕рдХрддреЗ рд╣реИрдВред рдЗрд╕ рд╕реНрдерд┐рддрд┐ рдХреЛ
рд╕рдВрд╡рд░реНрдзрд┐рдд рдХрд╣рд╛ рдЬрд╛рддрд╛ рд╣реИред рдЗрд╕ рд░рд╛рдЬреНрдп рдХреА рдЧрддрд┐рд╢реАрд▓рддрд╛ рдореВрд▓ рдЧрддрд┐рд╢реАрд▓рддрд╛ рд╕реЗ рддреБрдЪреНрдЫ рд░реВрдк рд╕реЗ рдкреНрд░рд╛рдкреНрдд рд╣реЛрддреА рд╣реИ:
рдлрд┐рд░ рдЗрд╕ рд╕рдВрд╡рд░реНрдзрд┐рдд рд░рд╛рдЬреНрдп рдХреЗ рд▓рд┐рдП рд╕рдВрдпреБрдЧреНрдо рд░рд╛рдЬреНрдп:
рдзреАрд░реЗ рд╕рдВрд╡рд░реНрдзрд┐рдд рдЧрддрд┐рд╢реАрд▓рддрд╛:
рд╕рдВрдпреБрдЧреНрдорд┐рдд рд╕рдВрд╡рд░реНрдзрд┐рдд рдЕрд╡рд╕реНрдерд╛ рдХреЗ рдЕрдВрддрд░ рд╕рдореАрдХрд░рдг рд╕реВрддреНрд░ рд╕реЗ (4) рддрдм:
рд╕рдордп рд╕реАрдорд╛ рдореЗрдВ рдЗрд╕ ODE рдХреЛ рд╣рд▓ рдХрд░рдирд╛:
рдХрд┐рд╕рдХреЗ рд╕рд╛рде рд╣реИ
рд╕рднреА рдЗрдирдкреБрдЯ рдкреИрд░рд╛рдореАрдЯрд░реНрд╕ рдореЗрдВ рдЧреНрд░реЗрдбрд┐рдПрдВрдЯреНрд╕ рдХреЛ
OESESolve ODE
solver рджреЗрддрд╛ рд╣реИ ред
рд╕рднреА рдЧреНрд░реЗрдбрд┐рдПрдВрдЯ (резреж), (резрез), (резреи), (резрей) рдХреЛ рд╕рдВрдпреБрдЧреНрдорд┐рдд рд╕рдВрд╡рд░реНрдзрд┐рдд рдЕрд╡рд╕реНрдерд╛ (реп) рдХреА рдЧрддрд┐рдХреА рдХреЗ рд╕рд╛рде рдПрдХ
рдУрдбреАрд╕реЙрд▓реНрд╡ рдХреЙрд▓ рдореЗрдВ рдПрдХ рд╕рд╛рде рдЧрдгрдирд╛ рдХреА рдЬрд╛ рд╕рдХрддреА рд╣реИред
рдореВрд▓ рд▓реЗрдЦ рд╕реЗ рдЪрд┐рддреНрд░рдгредрдКрдкрд░ рджрд┐рдП рдЧрдП рдПрд▓реНрдЧреЛрд░рд┐рджрдо рдХреНрд░рдорд┐рдХ рдЯрд┐рдкреНрдкрдгрд┐рдпреЛрдВ рдХреЗ рд▓рд┐рдП ODE рд╕рдорд╛рдзрд╛рди рдХреЗ рдврд╛рд▓ рдХреЗ рд░рд┐рд╡рд░реНрд╕ рдкреНрд░рд╕рд╛рд░ рдХрд╛ рд╡рд░реНрдгрди рдХрд░рддрд╛ рд╣реИред
рдПрдХ рдкреНрд░рдХреНрд╖реЗрдкрд╡рдХреНрд░ рдкрд░ рдХрдИ рдЯрд┐рдкреНрдкрдгрд┐рдпреЛрдВ рдХреЗ рдорд╛рдорд▓реЗ рдореЗрдВ, рд╕рдм рдХреБрдЫ рдЙрд╕реА рддрд░рд╣ рд╕реЗ рдЧрдгрдирд╛ рдХреА рдЬрд╛рддреА рд╣реИ, рд▓реЗрдХрд┐рди рдЯрд┐рдкреНрдкрдгрд┐рдпреЛрдВ рдХреЗ рдХреНрд╖рдгреЛрдВ рдореЗрдВ, рдкреНрд░рдЪрд╛рд░рд┐рдд рдврд╛рд▓ рдХреЗ рд╡реНрдпреБрддреНрдХреНрд░рдо рдХреЛ рд╡рд░реНрддрдорд╛рди рдЕрд╡рд▓реЛрдХрди рд╕реЗ рдЧреНрд░реЗрдбрд┐рдПрдВрдЯ рдХреЗ рд╕рд╛рде рд╕рдорд╛рдпреЛрдЬрд┐рдд рдХрд┐рдпрд╛ рдЬрд╛рдирд╛ рдЪрд╛рд╣рд┐рдП, рдЬреИрд╕рд╛ рдХрд┐
рдЪрд┐рддреНрд░ 1 рдореЗрдВ рджрд┐рдЦрд╛рдпрд╛ рдЧрдпрд╛ рд╣реИред
рдХрд╛рд░реНрдпрд╛рдиреНрд╡рдпрди
рдиреАрдЪреЗ рджрд┐рдП рдЧрдП рдХреЛрдб рдореЗрд░реЗ
рддрдВрддреНрд░рд┐рдХрд╛ рдиреНрдпреВрд░реЙрдЬ рдХрд╛ рдХрд╛рд░реНрдпрд╛рдиреНрд╡рдпрди рд╣реИред рдореИрдВрдиреЗ рд╡рд┐рд╢реБрджреНрдз рд░реВрдк рд╕реЗ рдпрд╣ рд╕рдордЭрдиреЗ рдХреЗ рд▓рд┐рдП рдХрд┐рдпрд╛ рдХрд┐ рдХреНрдпрд╛ рд╣реЛ рд░рд╣рд╛ рд╣реИред рд╣рд╛рд▓рд╛рдВрдХрд┐, рдпрд╣ рд▓реЗрдЦ рдХреЗ рд▓реЗрдЦрдХреЛрдВ рдХреЗ
рднрдВрдбрд╛рд░ рдореЗрдВ рд▓рд╛рдЧреВ рд╣реЛрдиреЗ рдХреЗ рдмрд╣реБрдд рдХрд░реАрдм рд╣реИред рдЗрд╕рдореЗрдВ рд╡рд╣ рд╕рднреА рдХреЛрдб рд╣реЛрддреЗ рд╣реИрдВ, рдЬрд┐рдиреНрд╣реЗрдВ рдЖрдкрдХреЛ рдПрдХ рд╕реНрдерд╛рди рдкрд░ рд╕рдордЭрдиреЗ рдХреА рдЖрд╡рд╢реНрдпрдХрддрд╛ рд╣реЛрддреА рд╣реИ, рдпрд╣ рдереЛрдбрд╝рд╛ рдФрд░ рдЕрдзрд┐рдХ рдЯрд┐рдкреНрдкрдгреА рдХрд░рддрд╛ рд╣реИред рд╡рд╛рд╕реНрддрд╡рд┐рдХ рдЕрдиреБрдкреНрд░рдпреЛрдЧреЛрдВ рдФрд░ рдкреНрд░рдпреЛрдЧреЛрдВ рдХреЗ рд▓рд┐рдП, рдореВрд▓ рд▓реЗрдЦ рдХреЗ рд▓реЗрдЦрдХреЛрдВ рдХреЗ рдХрд╛рд░реНрдпрд╛рдиреНрд╡рдпрди рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдирд╛ рдЕрднреА рднреА рдмреЗрд╣рддрд░ рд╣реИред
import math import numpy as np from IPython.display import clear_output from tqdm import tqdm_notebook as tqdm import matplotlib as mpl import matplotlib.pyplot as plt %matplotlib inline import seaborn as sns sns.color_palette("bright") import matplotlib as mpl import matplotlib.cm as cm import torch from torch import Tensor from torch import nn from torch.nn import functional as F from torch.autograd import Variable use_cuda = torch.cuda.is_available()
рд╕рдмрд╕реЗ рдкрд╣рд▓реЗ рдЖрдкрдХреЛ ODE рд╕рд┐рд╕реНрдЯрдо рдХреЗ рд╡рд┐рдХрд╛рд╕ рдХреЗ рд▓рд┐рдП рдХрд┐рд╕реА рднреА рд╡рд┐рдзрд┐ рдХреЛ рд▓рд╛рдЧреВ рдХрд░рдиреЗ рдХреА рдЖрд╡рд╢реНрдпрдХрддрд╛ рд╣реИред рд╕рд╛рджрдЧреА рдХреЗ рд▓рд┐рдП, рдпреВрд▓рд░ рдкрджреНрдзрддрд┐ рдХреЛ рдпрд╣рд╛рдВ рд▓рд╛рдЧреВ рдХрд┐рдпрд╛ рдЧрдпрд╛ рд╣реИ, рд╣рд╛рд▓рд╛рдВрдХрд┐ рдХреЛрдИ рднреА рд╕реНрдкрд╖реНрдЯ рдпрд╛ рдирд┐рд╣рд┐рдд рд╡рд┐рдзрд┐ рдЙрдкрдпреБрдХреНрдд рд╣реИред
def ode_solve(z0, t0, t1, f): """ - """ h_max = 0.05 n_steps = math.ceil((abs(t1 - t0)/h_max).max().item()) h = (t1 - t0)/n_steps t = t0 z = z0 for i_step in range(n_steps): z = z + h * f(z, t) t = t + h return z
рдпрд╣ рдЙрдкрдпреЛрдЧреА рддрд░реАрдХреЛрдВ рдХреЗ рдПрдХ рдЬреЛрдбрд╝реЗ рдХреЗ рд╕рд╛рде рдПрдХ рдорд╛рдирдХреАрдХреГрдд рдбрд╛рдпрдирд╛рдорд┐рдХреНрд╕ рдлрд╝рдВрдХреНрд╢рди рдХреЗ рд╕реБрдкрд░рдХреНрд▓рд╛рд╕ рдХрд╛ рднреА рд╡рд░реНрдгрди рдХрд░рддрд╛ рд╣реИред
рдкрд╣рд▓рд╛: рдЖрдкрдХреЛ рдЙрди рд╕рднреА рдорд╛рдкрджрдВрдбреЛрдВ рдХреЛ рд╡рд╛рдкрд╕ рдХрд░рдиреЗ рдХреА рдЖрд╡рд╢реНрдпрдХрддрд╛ рд╣реИ рдЬрд┐рди рдкрд░ рдлрд╝рдВрдХреНрд╢рди рд╡реЗрдХреНрдЯрд░ рдХреЗ рд░реВрдк рдореЗрдВ рдирд┐рд░реНрднрд░ рдХрд░рддрд╛ рд╣реИред
рджреВрд╕рд░реА рдмрд╛рдд: рд╕рдВрд╡рд░реНрдзрд┐рдд рдЧрддрд┐рд╢реАрд▓рддрд╛ рдХреА рдЧрдгрдирд╛ рдХрд░рдирд╛ рдЖрд╡рд╢реНрдпрдХ рд╣реИред рдпрд╣ рдЧрддрд┐рд╢реАрд▓рддрд╛ рдорд╛рдкрджрдВрдбреЛрдВ рдФрд░ рдЗрдирдкреБрдЯ рдбреЗрдЯрд╛ рдХреЗ рд╕рдВрджрд░реНрдн рдореЗрдВ рдкреИрд░рд╛рдореАрдЯрд░ рдХрд┐рдП рдЧрдП рдлрд╝рдВрдХреНрд╢рди рдХреЗ рдЧреНрд░реЗрдбрд┐рдПрдВрдЯ рдкрд░ рдирд┐рд░реНрднрд░ рдХрд░рддреА рд╣реИред рдкреНрд░рддреНрдпреЗрдХ рдирдИ рд╡рд╛рд╕реНрддреБрдХрд▓рд╛ рдХреЗ рд▓рд┐рдП рдкреНрд░рддреНрдпреЗрдХ рд╣рд╛рде рд╕реЗ рдврд╛рд▓ рдХреЛ рдкрдВрдЬреАрдХреГрдд рдирд╣реАрдВ рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП, рд╣рдо
torch.autograd.grad рд╡рд┐рдзрд┐ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░реЗрдВрдЧреЗред
class ODEF(nn.Module): def forward_with_grad(self, z, t, grad_outputs): """Compute f and a df/dz, a df/dp, a df/dt""" batch_size = z.shape[0] out = self.forward(z, t) a = grad_outputs adfdz, adfdt, *adfdp = torch.autograd.grad( (out,), (z, t) + tuple(self.parameters()), grad_outputs=(a), allow_unused=True, retain_graph=True )
рдиреАрдЪреЗ рджрд┐рдП рдЧрдП рдХреЛрдб рдореЗрдВ
рддрдВрддреНрд░рд┐рдХрд╛ ODE рдХреЗ рд▓рд┐рдП рдЖрдЧреЗ рдФрд░ рдкреАрдЫреЗ рдХреЗ рдкреНрд░рд╕рд╛рд░ рдХрд╛ рд╡рд░реНрдгрди рд╣реИред рдЖрдкрдХреЛ рдЗрд╕ рдХреЛрдб рдХреЛ рдореБрдЦреНрдп
torch.nn.Module рд╕реЗ
torch.autograd.Function рдлрд╝рдВрдХреНрд╢рди рдХреЗ рд░реВрдк рдореЗрдВ
рдЕрд▓рдЧ рдХрд░рдирд╛ рд╣реЛрдЧрд╛ рдХреНрдпреЛрдВрдХрд┐ рдЙрддреНрддрд░рд╛рд░реНрджреНрдз рдореЗрдВ рдЖрдк рдПрдХ рдореЙрдбреНрдпреВрд▓ рдХреЗ рд╡рд┐рдкрд░реАрдд, рдПрдХ рдордирдорд╛рдирд╛
backpropagation рд╡рд┐рдзрд┐ рдХреЛ рд▓рд╛рдЧреВ рдХрд░ рд╕рдХрддреЗ рд╣реИрдВред рддреЛ рдпрд╣ рд╕рд┐рд░реНрдл рдПрдХ рдмреИрд╕рд╛рдЦреА рд╣реИред
рдпрд╣ рд╕реБрд╡рд┐рдзрд╛ рд╕рдВрдкреВрд░реНрдг
рддрдВрддреНрд░рд┐рдХрд╛ ODE рджреГрд╖реНрдЯрд┐рдХреЛрдг рдХреЛ рд░реЗрдЦрд╛рдВрдХрд┐рдд рдХрд░рддреА рд╣реИред
class ODEAdjoint(torch.autograd.Function): @staticmethod def forward(ctx, z0, t, flat_parameters, func): assert isinstance(func, ODEF) bs, *z_shape = z0.size() time_len = t.size(0) with torch.no_grad(): z = torch.zeros(time_len, bs, *z_shape).to(z0) z[0] = z0 for i_t in range(time_len - 1): z0 = ode_solve(z0, t[i_t], t[i_t+1], func) z[i_t+1] = z0 ctx.func = func ctx.save_for_backward(t, z.clone(), flat_parameters) return z @staticmethod def backward(ctx, dLdz): """ dLdz shape: time_len, batch_size, *z_shape """ func = ctx.func t, z, flat_parameters = ctx.saved_tensors time_len, bs, *z_shape = z.size() n_dim = np.prod(z_shape) n_params = flat_parameters.size(0)
рдЕрдм рд╕реБрд╡рд┐рдзрд╛ рдХреЗ рд▓рд┐рдП, рдЗрд╕ рдлрд╝рдВрдХреНрд╢рди рдХреЛ
nn.Module рдореЗрдВ
рд▓рдкреЗрдЯреЗрдВ ред
class NeuralODE(nn.Module): def __init__(self, func): super(NeuralODE, self).__init__() assert isinstance(func, ODEF) self.func = func def forward(self, z0, t=Tensor([0., 1.]), return_whole_sequence=False): t = t.to(z0) z = ODEAdjoint.apply(z0, t, self.func.flatten_parameters(), self.func) if return_whole_sequence: return z else: return z[-1]
рдЖрд╡реЗрджрди
рд╡рд╛рд╕реНрддрд╡рд┐рдХ рдЧрддрд┐рд╢реАрд▓рддрд╛ рд╕рдорд╛рд░реЛрд╣ рдХреА рдкреБрдирд░реНрдкреНрд░рд╛рдкреНрддрд┐ (рджреГрд╖реНрдЯрд┐рдХреЛрдг рд╕рддреНрдпрд╛рдкрди)
рдПрдХ рдмреБрдирд┐рдпрд╛рджреА рдкрд░реАрдХреНрд╖рдг рдХреЗ рд░реВрдк рдореЗрдВ, рдЖрдЗрдП рдЕрдм рджреЗрдЦреЗрдВ рдХрд┐ рдХреНрдпрд╛
рддрдВрддреНрд░рд┐рдХрд╛ ODE рдЕрд╡рд▓реЛрдХрди рд╕рдВрдмрдВрдзреА рдбреЗрдЯрд╛ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ рдЧрддрд┐рд╢реАрд▓рддрд╛ рдХреЗ рд╡рд╛рд╕реНрддрд╡рд┐рдХ рдХрд╛рд░реНрдп рдХреЛ рдкреБрдирд░реНрд╕реНрдерд╛рдкрд┐рдд рдХрд░ рд╕рдХрддрд╛ рд╣реИред
рдРрд╕рд╛ рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП, рд╣рдо рдкрд╣рд▓реЗ ODE рдХреЗ рдбрд╛рдпрдирд╛рдорд┐рдХреНрд╕ рдлрд╝рдВрдХреНрд╢рди рдХреЛ рдирд┐рд░реНрдзрд╛рд░рд┐рдд рдХрд░рддреЗ рд╣реИрдВ, рдЗрд╕рдХреЗ рдЖрдзрд╛рд░ рдкрд░ рдкреНрд░рдХреНрд╖реЗрдкрд╡рдХреНрд░ рдХреЛ рд╡рд┐рдХрд╕рд┐рдд рдХрд░рддреЗ рд╣реИрдВ, рдФрд░ рдлрд┐рд░ рдЗрд╕реЗ рдмреЗрддрд░рддреАрдм рдврдВрдЧ рд╕реЗ рдкреИрд░рд╛рдореАрдЯрд░ рдХрд┐рдП рдЧрдП рдбрд╛рдпрдиреЗрдорд┐рдХреНрд╕ рдлрд╝рдВрдХреНрд╢рди рд╕реЗ рдкреБрдирд░реНрд╕реНрдерд╛рдкрд┐рдд рдХрд░рдиреЗ рдХрд╛ рдкреНрд░рдпрд╛рд╕ рдХрд░рддреЗ рд╣реИрдВред
рд╕рдмрд╕реЗ рдкрд╣рд▓реЗ, рд╣рдо рдПрдХ рд░реИрдЦрд┐рдХ ODE рдХреЗ рд╕рдмрд╕реЗ рд╕рд░рд▓ рдорд╛рдорд▓реЗ рдХреА рдЬрд╛рдВрдЪ рдХрд░рддреЗ рд╣реИрдВред рдЧрддрд┐рдХреА рдХрд╛ рдХрд╛рд░реНрдп рдХреЗрд╡рд▓ рдПрдХ рдореИрдЯреНрд░рд┐рдХреНрд╕ рдХреА рдХреНрд░рд┐рдпрд╛ рд╣реИред
рдПрдХ рдпрд╛рджреГрдЪреНрдЫрд┐рдХ рдореИрдЯреНрд░рд┐рдХреНрд╕ рджреНрд╡рд╛рд░рд╛ рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдлрд╝рдВрдХреНрд╢рди рдкреИрд░рд╛рдЯреНрд░рд╛рдЗрдЬреНрдб рд╣реИред
рдЗрд╕рдХреЗ рдЕрд▓рд╛рд╡рд╛, рдереЛрдбрд╝рд╛ рдФрд░ рдЕрдзрд┐рдХ рдкрд░рд┐рд╖реНрдХреГрдд рдЧрддрд┐рд╢реАрд▓рддрд╛ (рдЬрд┐рдл рдХреЗ рдмрд┐рдирд╛, рдХреНрдпреЛрдВрдХрд┐ рд╕реАрдЦрдиреЗ рдХреА рдкреНрд░рдХреНрд░рд┐рдпрд╛ рдЗрддрдиреА рд╕реБрдВрджрд░ рдирд╣реАрдВ рд╣реИ :))
рдпрд╣рд╛рдВ рд╕реАрдЦрдиреЗ рдХрд╛ рдХрд╛рд░реНрдп рдПрдХ рдЫрд┐рдкреА рд╣реБрдИ рдкрд░рдд рдХреЗ рд╕рд╛рде рдкреВрд░реА рддрд░рд╣ рд╕реЗ рдЬреБрдбрд╝рд╛ рд╣реБрдЖ рдиреЗрдЯрд╡рд░реНрдХ рд╣реИред

рдХреЛрдб class LinearODEF(ODEF): def __init__(self, W): super(LinearODEF, self).__init__() self.lin = nn.Linear(2, 2, bias=False) self.lin.weight = nn.Parameter(W) def forward(self, x, t): return self.lin(x)
рдбрд╛рдпрдирд╛рдорд┐рдХреНрд╕ рдлрд╝рдВрдХреНрд╢рди рдХреЗрд╡рд▓ рдПрдХ рдореИрдЯреНрд░рд┐рдХреНрд╕ рд╣реИ
class SpiralFunctionExample(LinearODEF): def __init__(self): matrix = Tensor([[-0.1, -1.], [1., -0.1]]) super(SpiralFunctionExample, self).__init__(matrix)
рдмреЗрддрд░рддреАрдм рдврдВрдЧ рд╕реЗ рдкреИрд░рд╛рдореАрдЯрд░ рдореИрдЯреНрд░рд┐рдХреНрд╕
class RandomLinearODEF(LinearODEF): def __init__(self): super(RandomLinearODEF, self).__init__(torch.randn(2, 2)/2.)
рдЕрдзрд┐рдХ рдкрд░рд┐рд╖реНрдХреГрдд рдкреНрд░рдХреНрд╖реЗрдкрд╡рдХреНрд░ рдХреЗ рд▓рд┐рдП рдЧрддрд┐рд╢реАрд▓рддрд╛
class TestODEF(ODEF): def __init__(self, A, B, x0): super(TestODEF, self).__init__() self.A = nn.Linear(2, 2, bias=False) self.A.weight = nn.Parameter(A) self.B = nn.Linear(2, 2, bias=False) self.B.weight = nn.Parameter(B) self.x0 = nn.Parameter(x0) def forward(self, x, t): xTx0 = torch.sum(x*self.x0, dim=1) dxdt = torch.sigmoid(xTx0) * self.A(x - self.x0) + torch.sigmoid(-xTx0) * self.B(x + self.x0) return dxdt
рдПрдХ рдкреВрд░реА рддрд░рд╣ рд╕реЗ рдЬреБрдбрд╝реЗ рдиреЗрдЯрд╡рд░реНрдХ рдХреЗ рд░реВрдк рдореЗрдВ рдЧрддрд┐рд╢реАрд▓рддрд╛ рд╕реАрдЦрдирд╛
class NNODEF(ODEF): def __init__(self, in_dim, hid_dim, time_invariant=False): super(NNODEF, self).__init__() self.time_invariant = time_invariant if time_invariant: self.lin1 = nn.Linear(in_dim, hid_dim) else: self.lin1 = nn.Linear(in_dim+1, hid_dim) self.lin2 = nn.Linear(hid_dim, hid_dim) self.lin3 = nn.Linear(hid_dim, in_dim) self.elu = nn.ELU(inplace=True) def forward(self, x, t): if not self.time_invariant: x = torch.cat((x, t), dim=-1) h = self.elu(self.lin1(x)) h = self.elu(self.lin2(h)) out = self.lin3(h) return out def to_np(x): return x.detach().cpu().numpy() def plot_trajectories(obs=None, times=None, trajs=None, save=None, figsize=(16, 8)): plt.figure(figsize=figsize) if obs is not None: if times is None: times = [None] * len(obs) for o, t in zip(obs, times): o, t = to_np(o), to_np(t) for b_i in range(o.shape[1]): plt.scatter(o[:, b_i, 0], o[:, b_i, 1], c=t[:, b_i, 0], cmap=cm.plasma) if trajs is not None: for z in trajs: z = to_np(z) plt.plot(z[:, 0, 0], z[:, 0, 1], lw=1.5) if save is not None: plt.savefig(save) plt.show() def conduct_experiment(ode_true, ode_trained, n_steps, name, plot_freq=10):
рдЬреИрд╕рд╛ рдХрд┐ рдЖрдк рджреЗрдЦ рд╕рдХрддреЗ рд╣реИрдВ,
рддрдВрддреНрд░рд┐рдХрд╛ ODE рдЧрддрд┐рд╢реАрд▓рддрд╛ рдХреЛ рдмрд╣рд╛рд▓ рдХрд░рдиреЗ рдХрд╛ рдПрдХ рдмрд╣реБрдд рдЕрдЪреНрдЫрд╛ рдХрд╛рдо рдХрд░рддрд╛ рд╣реИред рдпрд╣реА рд╣реИ, рдЕрд╡рдзрд╛рд░рдгрд╛ рдПрдХ рдкреВрд░реЗ рдХрд╛рдо рдХрд░рддрд╛ рд╣реИред
рдЕрдм рдереЛрдбрд╝реА рдФрд░ рдЬрдЯрд┐рд▓ рд╕рдорд╕реНрдпрд╛ (MNIST, haha) рдкрд░ рдЬрд╛рдБрдЪ рдХрд░реЗрдВред
рддрдВрддреНрд░рд┐рдХрд╛ ODE ResNets рд╕реЗ рдкреНрд░реЗрд░рд┐рдд рд╣реИ
ResNet'ax рдореЗрдВ, рд╕реВрддреНрд░ рдХреЗ рдЕрдиреБрд╕рд╛рд░ рдЫрд┐рдкреА рд╣реБрдИ рд╕реНрдерд┐рддрд┐ рдмрджрд▓ рдЬрд╛рддреА рд╣реИ
рдЬрд╣рд╛рдБ

рдмреНрд▓реЙрдХ рдирдВрдмрд░ рд╣реИ рдФрд░

рдпрд╣ рдПрдХ рдлрд╝рдВрдХреНрд╢рди рд╣реИ рдЬрд┐рд╕реЗ рдмреНрд▓реЙрдХ рдХреЗ рдЕрдВрджрд░ рдХреА рдкрд░рддреЛрдВ рджреНрд╡рд╛рд░рд╛ рд╕реАрдЦрд╛ рдЬрд╛рддрд╛ рд╣реИред
рд╕реАрдорд╛ рдореЗрдВ, рдпрджрд┐ рд╣рдо рдХрднреА рдЫреЛрдЯреЗ рдХрджрдореЛрдВ рдХреЗ рд╕рд╛рде рдЕрдирдВрдд рд╕рдВрдЦреНрдпрд╛ рдореЗрдВ рдмреНрд▓реЙрдХ рд▓реЗрддреЗ рд╣реИрдВ, рддреЛ рд╣рдореЗрдВ рдУрдбреАрдИ рдХреЗ рд░реВрдк рдореЗрдВ рдЫрд┐рдкреА рд╣реБрдИ рдкрд░рдд рдХреА рдирд┐рд░рдВрддрд░ рдЧрддрд┐рд╢реАрд▓рддрд╛ рдорд┐рд▓рддреА рд╣реИ, рдареАрдХ рдЙрд╕реА рддрд░рд╣ рдЬреИрд╕рд╛ рдХрд┐ рдКрдкрд░ рдерд╛ред
рдЗрдирдкреБрдЯ рд▓реЗрдпрд░ рд╕реЗ рд╢реБрд░реВ

рд╣рдо рдЖрдЙрдЯрдкреБрдЯ рд▓реЗрдпрд░ рдХреЛ рдкрд░рд┐рднрд╛рд╖рд┐рдд рдХрд░ рд╕рдХрддреЗ рд╣реИрдВ

рд╕рдордп рдкрд░ рдЗрд╕ ODE рдХреЗ рд╕рдорд╛рдзрд╛рди рдХреЗ рд░реВрдк рдореЗрдВ рдЯреАред
рдЕрдм рд╣рдо рдЧрд┐рди рд╕рдХрддреЗ рд╣реИрдВ

рд╕рднреА infinitesimal рдмреНрд▓реЙрдХреЛрдВ рдХреЗ рдмреАрдЪ рд╡рд┐рддрд░рд┐рдд (
рд╕рд╛рдЭрд╛ ) рдорд╛рдкрджрдВрдбреЛрдВ рдХреЗ рд░реВрдк рдореЗрдВред
MNIST рдкрд░ рддрдВрддреНрд░рд┐рдХрд╛ ODE рд╡рд╛рд╕реНрддреБрдХрд▓рд╛ рдХреЛ рдорд╛рдиреНрдп рдХрд░рдирд╛
рдЗрд╕ рднрд╛рдЧ рдореЗрдВ, рд╣рдо
рддрдВрддреНрд░рд┐рдХрд╛ ODE рдХреА рдХреНрд╖рдорддрд╛ рдХрд╛ рдкрд░реАрдХреНрд╖рдг рдЕрдзрд┐рдХ рдкрд░рд┐рдЪрд┐рдд рд╡рд╛рд╕реНрддреБрд╢рд┐рд▓реНрдк рдореЗрдВ рдШрдЯрдХреЛрдВ рдХреЗ рд░реВрдк рдореЗрдВ рдХрд░реЗрдВрдЧреЗред
рд╡рд┐рд╢реЗрд╖ рд░реВрдк рд╕реЗ, рд╣рдо MNIST рдХреНрд▓рд╛рд╕рд┐рдлрд╛рдпрд░ рдореЗрдВ
рдиреНрдпреВрд░рд▓ ODE рдХреЗ рд╕рд╛рде рдЕрд╡рд╢рд┐рд╖реНрдЯ рдмреНрд▓реЙрдХреЛрдВ рдХреЛ рдкреНрд░рддрд┐рд╕реНрдерд╛рдкрд┐рдд рдХрд░реЗрдВрдЧреЗред

рдХреЛрдб def norm(dim): return nn.BatchNorm2d(dim) def conv3x3(in_feats, out_feats, stride=1): return nn.Conv2d(in_feats, out_feats, kernel_size=3, stride=stride, padding=1, bias=False) def add_time(in_tensor, t): bs, c, w, h = in_tensor.shape return torch.cat((in_tensor, t.expand(bs, 1, w, h)), dim=1) class ConvODEF(ODEF): def __init__(self, dim): super(ConvODEF, self).__init__() self.conv1 = conv3x3(dim + 1, dim) self.norm1 = norm(dim) self.conv2 = conv3x3(dim + 1, dim) self.norm2 = norm(dim) def forward(self, x, t): xt = add_time(x, t) h = self.norm1(torch.relu(self.conv1(xt))) ht = add_time(h, t) dxdt = self.norm2(torch.relu(self.conv2(ht))) return dxdt class ContinuousNeuralMNISTClassifier(nn.Module): def __init__(self, ode): super(ContinuousNeuralMNISTClassifier, self).__init__() self.downsampling = nn.Sequential( nn.Conv2d(1, 64, 3, 1), norm(64), nn.ReLU(inplace=True), nn.Conv2d(64, 64, 4, 2, 1), norm(64), nn.ReLU(inplace=True), nn.Conv2d(64, 64, 4, 2, 1), ) self.feature = ode self.norm = norm(64) self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(64, 10) def forward(self, x): x = self.downsampling(x) x = self.feature(x) x = self.norm(x) x = self.avg_pool(x) shape = torch.prod(torch.tensor(x.shape[1:])).item() x = x.view(-1, shape) out = self.fc(x) return out func = ConvODEF(64) ode = NeuralODE(func) model = ContinuousNeuralMNISTClassifier(ode) if use_cuda: model = model.cuda() import torchvision img_std = 0.3081 img_mean = 0.1307 batch_size = 32 train_loader = torch.utils.data.DataLoader( torchvision.datasets.MNIST("data/mnist", train=True, download=True, transform=torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((img_mean,), (img_std,)) ]) ), batch_size=batch_size, shuffle=True ) test_loader = torch.utils.data.DataLoader( torchvision.datasets.MNIST("data/mnist", train=False, download=True, transform=torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((img_mean,), (img_std,)) ]) ), batch_size=128, shuffle=True ) optimizer = torch.optim.Adam(model.parameters()) def train(epoch): num_items = 0 train_losses = [] model.train() criterion = nn.CrossEntropyLoss() print(f"Training Epoch {epoch}...") for batch_idx, (data, target) in tqdm(enumerate(train_loader), total=len(train_loader)): if use_cuda: data = data.cuda() target = target.cuda() optimizer.zero_grad() output = model(data) loss = criterion(output, target) loss.backward() optimizer.step() train_losses += [loss.item()] num_items += data.shape[0] print('Train loss: {:.5f}'.format(np.mean(train_losses))) return train_losses def test(): accuracy = 0.0 num_items = 0 model.eval() criterion = nn.CrossEntropyLoss() print(f"Testing...") with torch.no_grad(): for batch_idx, (data, target) in tqdm(enumerate(test_loader), total=len(test_loader)): if use_cuda: data = data.cuda() target = target.cuda() output = model(data) accuracy += torch.sum(torch.argmax(output, dim=1) == target).item() num_items += data.shape[0] accuracy = accuracy * 100 / num_items print("Test Accuracy: {:.3f}%".format(accuracy)) n_epochs = 5 test() train_losses = [] for epoch in range(1, n_epochs + 1): train_losses += train(epoch) test() import pandas as pd plt.figure(figsize=(9, 5)) history = pd.DataFrame({"loss": train_losses}) history["cum_data"] = history.index * batch_size history["smooth_loss"] = history.loss.ewm(halflife=10).mean() history.plot(x="cum_data", y="smooth_loss", figsize=(12, 5), title="train error")
Testing... 100% 79/79 [00:01<00:00, 45.69it/s] Test Accuracy: 9.740% Training Epoch 1... 100% 1875/1875 [01:15<00:00, 24.69it/s] Train loss: 0.20137 Testing... 100% 79/79 [00:01<00:00, 46.64it/s] Test Accuracy: 98.680% Training Epoch 2... 100% 1875/1875 [01:17<00:00, 24.32it/s] Train loss: 0.05059 Testing... 100% 79/79 [00:01<00:00, 46.11it/s] Test Accuracy: 97.760% Training Epoch 3... 100% 1875/1875 [01:16<00:00, 24.63it/s] Train loss: 0.03808 Testing... 100% 79/79 [00:01<00:00, 45.65it/s] Test Accuracy: 99.000% Training Epoch 4... 100% 1875/1875 [01:17<00:00, 24.28it/s] Train loss: 0.02894 Testing... 100% 79/79 [00:01<00:00, 45.42it/s] Test Accuracy: 99.130% Training Epoch 5... 100% 1875/1875 [01:16<00:00, 24.67it/s] Train loss: 0.02424 Testing... 100% 79/79 [00:01<00:00, 45.89it/s] Test Accuracy: 99.170%

рдХреЗрд╡рд▓ 5 рдпреБрдЧреЛрдВ рдФрд░ 6 рдорд┐рдирдЯ рдХреЗ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЗ рджреМрд░рд╛рди рдмрд╣реБрдд рдХрдард┐рди рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЗ рдмрд╛рдж, рдореЙрдбрд▓ рдкрд╣рд▓реЗ рд╣реА 1% рд╕реЗ рдХрдо рдХреА рдкрд░реАрдХреНрд╖рдг рддреНрд░реБрдЯрд┐ рддрдХ рдкрд╣реБрдВрдЪ рдЧрдпрд╛ рд╣реИред рд╣рдо рдХрд╣ рд╕рдХрддреЗ рд╣реИрдВ рдХрд┐
рддрдВрддреНрд░рд┐рдХрд╛ ODEs рдПрдХ рдШрдЯрдХ рдХреЗ
рд░реВрдк рдореЗрдВ рдЕрдзрд┐рдХ рдкрд╛рд░рдВрдкрд░рд┐рдХ рдиреЗрдЯрд╡рд░реНрдХ рдореЗрдВ рдПрдХреАрдХреГрдд рд╣реЛрддреЗ рд╣реИрдВред
рдЕрдкрдиреЗ рд▓реЗрдЦ рдореЗрдВ, рд▓реЗрдЦрдХ рдЗрд╕ рдХреНрд▓рд╛рд╕рд┐рдлрд╝рд╛рдпрд░ (ODE-Net) рдХреА рддреБрд▓рдирд╛ рдПрдХ рдирд┐рдпрдорд┐рдд рд░реВрдк рд╕реЗ рдкреВрд░реА рддрд░рд╣ рд╕реЗ рдЬреБрдбрд╝реЗ рдиреЗрдЯрд╡рд░реНрдХ рдХреЗ рд╕рд╛рде, рдПрдХ рд╕рдорд╛рди рдЖрд░реНрдХрд┐рдЯреЗрдХреНрдЪрд░ рдХреЗ рд╕рд╛рде ResNet рдХреЗ рд╕рд╛рде, рдФрд░ рдареАрдХ рдЙрд╕реА рд╡рд╛рд╕реНрддреБрдХрд▓рд╛ рдХреЗ рд╕рд╛рде рдХрд░рддреЗ рд╣реИрдВ, рдЬрд┐рд╕рдореЗрдВ рдврд╛рд▓ рд╕реАрдзреЗ
ODESolve (рд╕рдВрдпреБрдЧреНрдо рдврд╛рд▓ рд╡рд┐рдзрд┐ рдХреЗ рдмрд┐рдирд╛) рдореЗрдВ рд╕рдВрдЪрд╛рд▓рди рдХреЗ рдорд╛рдзреНрдпрдо рд╕реЗ рдкреНрд░рдЪрд╛рд░рд┐рдд рдХрд░рддрд╛ рд╣реИ ( рдЖрд░-рдиреЗрдЯ)ред
рдореВрд▓ рд▓реЗрдЦ рд╕реЗ рдЪрд┐рддреНрд░рдгредрдЙрдирдХреЗ рдЕрдиреБрд╕рд╛рд░, рдПрдХ 1-рд▓реЗрдпрд░ рдкреВрд░реА рддрд░рд╣ рд╕реЗ рдХрдиреЗрдХреНрдЯреЗрдб рдиреЗрдЯрд╡рд░реНрдХ рдХреЗ рд╕рд╛рде рд▓рдЧрднрдЧ рд╕рдорд╛рди рдорд╛рдкрджрдВрдбреЛрдВ рдХреА рд╕рдВрдЦреНрдпрд╛ рдХреЗ рд░реВрдк рдореЗрдВ
рдиреНрдпреВрд░рд▓ ODE рдореЗрдВ рдкрд░реАрдХреНрд╖рдг рдкрд░ рдПрдХ рдмрд╣реБрдд рдЕрдзрд┐рдХ рддреНрд░реБрдЯрд┐ рд╣реИ, рдПрдХ рд╣реА рддреНрд░реБрдЯрд┐ рд╡рд╛рд▓реЗ ResNet рдореЗрдВ рдмрд╣реБрдд рдЕрдзрд┐рдХ рдкреИрд░рд╛рдореАрдЯрд░ рд╣реИрдВ, рдФрд░ RK- рдиреЗрдЯ рдмрд┐рдирд╛ рд╕рдВрдпреБрдЧреНрдо рдЧреНрд░реЗрдбрд┐рдВрдЧ рд╡рд┐рдзрд┐ рдХреЗ рдереЛрдбрд╝реА рдЕрдзрд┐рдХ рддреНрд░реБрдЯрд┐ рд╣реИред рдФрд░ рдПрдХ рд░реИрдЦрд┐рдХ рд░реВрдк рд╕реЗ рдмрдврд╝рддреА рд╣реБрдИ рдореЗрдореЛрд░реА рдХреА рдЦрдкрдд рдХреЗ рд╕рд╛рде (рдЕрдиреБрдореЗрдп рддреНрд░реБрдЯрд┐
рдЬрд┐рддрдиреА рдЫреЛрдЯреА рд╣реЛ,
рдЙрддрдиреЗ рдЕрдзрд┐рдХ рдХрджрдо
ODESolve рд╣реЛрдиреЗ рдЪрд╛рд╣рд┐рдП, рдЬреЛ рд░реИрдЦрд┐рдХ рд░реВрдк рд╕реЗ рдЪрд░рдгреЛрдВ рдХреА рд╕рдВрдЦреНрдпрд╛ рдХреЗ рд╕рд╛рде рдореЗрдореЛрд░реА рдХреА рдЦрдкрдд рдХреЛ рдмрдврд╝рд╛рддреЗ рд╣реИрдВ)ред
рд▓реЗрдЦрдХ рдпрд╣рд╛рдВ рд╕рд░рд▓ рдЗрд▓рд░ рд╡рд┐рдзрд┐ рдХреЗ рд╡рд┐рдкрд░реАрдд, рдЙрдирдХреЗ рдХрд╛рд░реНрдпрд╛рдиреНрд╡рдпрди рдореЗрдВ рдЕрдиреБрдХреВрд▓реА рдЖрдХрд╛рд░ рдХреЗ рд╕рд╛рде рдирд┐рд╣рд┐рдд рд░рди-рдХреБрдЯреНрдЯрд╛ рд╡рд┐рдзрд┐ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рддреЗ рд╣реИрдВред рд╡реЗ рдирдИ рд╡рд╛рд╕реНрддреБрдХрд▓рд╛ рдХреЗ рдХреБрдЫ рдЧреБрдгреЛрдВ рдХрд╛ рднреА рдЕрдзреНрдпрдпрди рдХрд░рддреЗ рд╣реИрдВред
ODE- рдиреЗрдЯ рдлрд╝реАрдЪрд░ (NFE рдлреЙрд░рд╡рд░реНрдб - рдкреНрд░рддреНрдпрдХреНрд╖ рдкрд╛рд╕ рдореЗрдВ рдлрд╝рдВрдХреНрд╢рди рдХреА рд╕рдВрдЦреНрдпрд╛)
рдореВрд▓ рд▓реЗрдЦ рд╕реЗ рдЪрд┐рддреНрд░рдгред- (рдП) рд╕рдВрдЦреНрдпрд╛рддреНрдордХ рддреНрд░реБрдЯрд┐ рдХреЗ рд╕реНрд╡реАрдХрд╛рд░реНрдп рд╕реНрддрд░ рдХреЛ рдмрджрд▓рдиреЗ рд╕реЗ рдкреНрд░рддреНрдпрдХреНрд╖ рд╡рд┐рддрд░рдг рдореЗрдВ рдЪрд░рдгреЛрдВ рдХреА рд╕рдВрдЦреНрдпрд╛ рдореЗрдВ рдкрд░рд┐рд╡рд░реНрддрди рд╣реЛрддрд╛ рд╣реИред
- (b) рдкреНрд░рддреНрдпрдХреНрд╖ рд╡рд┐рддрд░рдг рдкрд░ рдЦрд░реНрдЪ рдХрд┐рдпрд╛ рдЧрдпрд╛ рд╕рдордп рдлрд╝рдВрдХреНрд╢рди рдХреА рдЧрдгрдирд╛ рдХреА рд╕рдВрдЦреНрдпрд╛ рдХреЗ рд▓рд┐рдП рдЖрдиреБрдкрд╛рддрд┐рдХ рд╣реИред
- (c) рд╡рд╛рдкрд╕ рдкреНрд░рд╕рд╛рд░ рдХреЗ рд▓рд┐рдП рдХрд╛рд░реНрдп рдХреА рдЧрдгрдирд╛ рдХреА рд╕рдВрдЦреНрдпрд╛ рдкреНрд░рддреНрдпрдХреНрд╖ рдкреНрд░рд╕рд╛рд░ рдХрд╛ рд▓рдЧрднрдЧ рдЖрдзрд╛ рд╣реИ, рдЬреЛ рдЗрдВрдЧрд┐рдд рдХрд░рддрд╛ рд╣реИ рдХрд┐ рд╕рдВрдпреБрдЧреНрдо рдЧреНрд░реЗрдбрд┐рдПрдВрдЯ рд╡рд┐рдзрд┐ рд╕реАрдзреЗ ODESolve рдХреЗ рдорд╛рдзреНрдпрдо рд╕реЗ рдЧреНрд░реЗрдбрд┐рдПрдВрдЯ рдХреЛ рдлреИрд▓рд╛рдиреЗ рдХреА рддреБрд▓рдирд╛ рдореЗрдВ рдЕрдзрд┐рдХ рдХрдореНрдкреНрдпреВрдЯреЗрд╢рдирд▓ рд░реВрдк рд╕реЗ рдХреБрд╢рд▓ рд╣реЛ рд╕рдХрддреА рд╣реИред
- (рдШ) рдУрдбреАрдИ-рдиреЗрдЯ рдЕрдзрд┐рдХ рд╕реЗ рдЕрдзрд┐рдХ рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рд╣реЛ рдЬрд╛рддрд╛ рд╣реИ, рдЗрд╕рдХреЗ рд▓рд┐рдП рдлрд╝рдВрдХреНрд╢рди рдХреА рдЕрдзрд┐рдХ рд╕реЗ рдЕрдзрд┐рдХ рдЧрдгрдирд╛ (рдПрдХ рдХрднреА рдЫреЛрдЯрд╛ рдХрджрдо) рдХреА рдЖрд╡рд╢реНрдпрдХрддрд╛ рд╣реЛрддреА рд╣реИ, рд╕рдВрднрд╡рддрдГ рдореЙрдбрд▓ рдХреА рдмрдврд╝рддреА рдЬрдЯрд┐рд▓рддрд╛ рдХреЗ рд▓рд┐рдП рдЕрдиреБрдХреВрд▓ред
рдЯрд╛рдЗрдо рд╕реАрд░реАрдЬрд╝ рдореЙрдбрд▓рд┐рдВрдЧ рдХреЗ рд▓рд┐рдП рд╣рд┐рдбрди рдЬрдирд░реЗрдЯрд┐рд╡ рдлрдВрдХреНрд╢рди
рддрдВрддреНрд░рд┐рдХрд╛ ODE рдкрде рдХреЗ рдЕрдЬреНрдЮрд╛рдд рд╕реНрдерд╛рди рдкрд░ рд╣реЛрдиреЗ рдкрд░ рднреА рдирд┐рд░рдВрддрд░ рдзрд╛рд░рд╛рд╡рд╛рд╣рд┐рдХ рдбреЗрдЯрд╛ рдХреЛ рд╕рдВрд╕рд╛рдзрд┐рдд рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП рдЙрдкрдпреБрдХреНрдд рд╣реИред
рдЗрд╕ рдЕрдиреБрднрд╛рдЧ рдореЗрдВ, рд╣рдо рдкреНрд░рдпреЛрдЧ рдХрд░реЗрдВрдЧреЗ
рдФрд░ рддрдВрддреНрд░рд┐рдХрд╛ ODE рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ рдирд┐рд░рдВрддрд░ рдЕрдиреБрдХреНрд░рдо рдХреА рдкреАрдврд╝реА рдХреЛ рдмрджрд▓ рджреЗрдВрдЧреЗ, рдФрд░ рд╕реАрдЦреЗ рдЧрдП рдЫрд┐рдкреЗ рд╣реБрдП рд╕реНрдерд╛рди рдкрд░ рдПрдХ рдирдЬрд╝рд░ рдбрд╛рд▓реЗрдВрдЧреЗред
рд▓реЗрдЦрдХ рдЗрд╕рдХреА рддреБрд▓рдирд╛ рдЖрд╡рд░реНрддрдХ рдиреЗрдЯрд╡рд░реНрдХ рджреНрд╡рд╛рд░рд╛ рдЙрддреНрдкрдиреНрди рд╕рдорд╛рди рдЕрдиреБрдХреНрд░рдореЛрдВ рд╕реЗ рднреА рдХрд░рддреЗ рд╣реИрдВред
рдпрд╣рд╛рдБ рдкреНрд░рдпреЛрдЧ рд▓реЗрдЦрдХ рдХреЗ рднрдВрдбрд╛рд░ рдореЗрдВ рд╕рдВрдмрдВрдзрд┐рдд рдЙрджрд╛рд╣рд░рдг рд╕реЗ рдереЛрдбрд╝рд╛ рдЕрд▓рдЧ рд╣реИ, рдпрд╣рд╛рдБ рдкрдереЛрдВ рдХрд╛ рдЕрдзрд┐рдХ рд╡рд┐рд╡рд┐рдз рд╕реЗрдЯ рд╣реИред
рдбреЗрдЯрд╛
рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдбреЗрдЯрд╛ рдореЗрдВ рдпрд╛рджреГрдЪреНрдЫрд┐рдХ рд╕рд░реНрдкрд┐рд▓ рд╣реЛрддреЗ рд╣реИрдВ, рдЬрд┐рдирдореЗрдВ рд╕реЗ рдЖрдзреЗ рджрдХреНрд╖рд┐рдгрд╛рд╡рд░реНрдд рд╣реЛрддреЗ рд╣реИрдВ, рдФрд░ рджреВрд╕рд░рд╛ рд╡рд╛рдорд╛рд╡рд░реНрддред рдЗрд╕рдХреЗ рдЕрд▓рд╛рд╡рд╛, рдЗрди рд╕рд░реНрдкрд┐рд▓реЛрдВ рд╕реЗ рдпрд╛рджреГрдЪреНрдЫрд┐рдХ рдХреНрд░рдореЛрдВ рдХрд╛ рдирдореВрдирд╛ рд▓рд┐рдпрд╛ рдЬрд╛рддрд╛ рд╣реИ, рдЬреЛ рд╡рд┐рдкрд░реАрдд рджрд┐рд╢рд╛ рдореЗрдВ рдХреЛрдбрд┐рдВрдЧ рдкреБрдирд░рд╛рд╡реГрддреНрддрд┐ рдореЙрдбрд▓ рджреНрд╡рд╛рд░рд╛ рд╕рдВрд╕рд╛рдзрд┐рдд рд╣реЛрддреЗ рд╣реИрдВ, рдЬрд┐рд╕рд╕реЗ рдПрдХ рдкреНрд░рд╛рд░рдВрднрд┐рдХ рдЫрд┐рдкреЗ рд╣реБрдП рд░рд╛рдЬреНрдп рдХреЛ рдЬрдиреНрдо рджрд┐рдпрд╛ рдЬрд╛рддрд╛ рд╣реИ, рдЬреЛ рддрдм рд╡рд┐рдХрд╕рд┐рдд рд╣реЛрддрд╛ рд╣реИ, рдЫрд┐рдкреА рд╣реБрдИ рдЬрдЧрд╣ рдореЗрдВ рдкреНрд░рдХреНрд╖реЗрдкрд╡рдХреНрд░ рдмрдирд╛рддрд╛ рд╣реИред рдЗрд╕ рдЕрд╡реНрдпрдХреНрдд рдкрде рдХреЛ рддрдм рдбреЗрдЯрд╛ рд╕реНрдерд╛рди рдкрд░ рдореИрдк рдХрд┐рдпрд╛ рдЬрд╛рддрд╛ рд╣реИ рдФрд░ рддреБрд▓рдирд╛ рдХрд┐рдП рдЧрдП рдкрд░рд┐рдгрд╛рдо рдХреЗ рд╕рд╛рде рддреБрд▓рдирд╛ рдХреА рдЬрд╛рддреА рд╣реИред рдЗрд╕ рдкреНрд░рдХрд╛рд░, рдореЙрдбрд▓ рдПрдХ рдбрд╛рдЯрд╛рд╕реЗрдЯ рдХреЗ рд╕рдорд╛рди рдкреНрд░рдХреНрд╖реЗрдкрд╡рдХреНрд░ рдЙрддреНрдкрдиреНрди рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП рд╕реАрдЦрддрд╛ рд╣реИред

рдбреЗрдЯрд╛рд╕реЗрдЯ рд╕рд░реНрдкрд┐рд▓ рдХреЗ рдЙрджрд╛рд╣рд░рдгрдПрдХ рдЬреЗрдиреЗрд░рд┐рдХ рдореЙрдбрд▓ рдХреЗ рд░реВрдк рдореЗрдВ рд╡реАрдПрдИ
рдирдореВрдирд╛ рдкреНрд░рдХреНрд░рд┐рдпрд╛ рдХреЗ рдорд╛рдзреНрдпрдо рд╕реЗ рдЬрдирд░реЗрдЯрд┐рд╡ рдореЙрдбрд▓:
рдЬрд┐рд╕реЗ рд╡реИрд░рд┐рдПрдмрд▓ рдСрдЯреЛ-рдПрдирдХреЛрдбрд░ рджреГрд╖реНрдЯрд┐рдХреЛрдг рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдХрд┐рдпрд╛ рдЬрд╛ рд╕рдХрддрд╛ рд╣реИред
- рдорд╛рдкрджрдВрдбреЛрдВ рдХреЛ рдкреНрд░рд╛рдкреНрдд рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП рд╕рдордп рдореЗрдВ рд╡рд╛рдкрд╕ рд╕рдордп рдЕрдиреБрдХреНрд░рдо рдХреЗ рдорд╛рдзреНрдпрдо рд╕реЗ рдПрдХ рдЖрд╡рд░реНрддрдХ рдПрдирдХреЛрдбрд░ рдХреЗ рдорд╛рдзреНрдпрдо рд╕реЗ рдЬрд╛рдУ
ред
рдкрд░рд┐рд╡рд░реНрддрдирд╢реАрд▓ рдкрд╢реНрдЪ рд╡рд┐рддрд░рдг, рдФрд░ рдлрд┐рд░ рдЙрд╕рд╕реЗ рдирдореВрдирд╛:
- рдЫрд┐рдкреЗ рд╣реБрдП рдкреНрд░рдХреНрд╖реЗрдкрд╡рдХреНрд░ рдкреНрд░рд╛рдкреНрдд рдХрд░реЗрдВ:
- рдХрд┐рд╕реА рдЕрдиреНрдп рддрдВрддреНрд░рд┐рдХрд╛ рдиреЗрдЯрд╡рд░реНрдХ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░ рдбреЗрдЯрд╛ рдореЗрдВ рдПрдХ рдкрде рдХреЗ рд▓рд┐рдП рдПрдХ рдЫрд┐рдкреЗ рд╣реБрдП рдорд╛рд░реНрдЧ рдХреЛ рдореИрдк рдХрд░реЗрдВ:

- рд╕реИрдВрдкрд▓ рдХрд┐рдП рдЧрдП рдкрде рдХреЗ рд▓рд┐рдП рд╡реИрдзрддрд╛ рдХреА рдирд┐рдЪрд▓реА рд╕реАрдорд╛ (ELBO) рдХреЗ рдЖрдХрд▓рди рдХреЛ рдЕрдзрд┐рдХрддрдо рдХрд░реЗрдВ:
рдФрд░ рдЧреЙрд╕рд┐рдпрди рдХреЗ рдмрд╛рдж рдХреЗ рд╡рд┐рддрд░рдг рдХреЗ рдорд╛рдорд▓реЗ рдореЗрдВ

рдФрд░ рдЬреНрдЮрд╛рдд рд╢реЛрд░ рд╕реНрддрд░

:
рдПрдХ рдЫрд┐рдкреЗ рд╣реБрдП ODE рдореЙрдбрд▓ рдХреЗ рдЕрднрд┐рдХрд▓рди рдЧреНрд░рд╛рдл рдХреЛ рдирд┐рдореНрдирд╛рдиреБрд╕рд╛рд░ рджрд░реНрд╢рд╛рдпрд╛ рдЬрд╛ рд╕рдХрддрд╛ рд╣реИ
рдореВрд▓ рд▓реЗрдЦ рд╕реЗ рдЪрд┐рддреНрд░рдгредрдЗрд╕ рдореЙрдбрд▓ рдХреЛ рддрдм рдкрд░реАрдХреНрд╖рдг рдХрд┐рдпрд╛ рдЬрд╛ рд╕рдХрддрд╛ рд╣реИ рдХрд┐ рдпрд╣ рдХреЗрд╡рд▓ рдкреНрд░рд╛рд░рдВрднрд┐рдХ рдЯрд┐рдкреНрдкрдгрд┐рдпреЛрдВ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ рдкрде рдХреЛ рдХреИрд╕реЗ рдкреНрд░рдХреНрд╖реЗрдкрд┐рдд рдХрд░рддрд╛ рд╣реИред
рдХреЛрдбрдореЙрдбрд▓ рдкрд░рд┐рднрд╛рд╖рд┐рдд рдХрд░реЗрдВ
class RNNEncoder(nn.Module): def __init__(self, input_dim, hidden_dim, latent_dim): super(RNNEncoder, self).__init__() self.input_dim = input_dim self.hidden_dim = hidden_dim self.latent_dim = latent_dim self.rnn = nn.GRU(input_dim+1, hidden_dim) self.hid2lat = nn.Linear(hidden_dim, 2*latent_dim) def forward(self, x, t):
рдбреЗрдЯрд╕реЗрдЯ рдЬрдирд░реЗрд╢рди
t_max = 6.29*5 n_points = 200 noise_std = 0.02 num_spirals = 1000 index_np = np.arange(0, n_points, 1, dtype=np.int) index_np = np.hstack([index_np[:, None]]) times_np = np.linspace(0, t_max, num=n_points) times_np = np.hstack([times_np[:, None]] * num_spirals) times = torch.from_numpy(times_np[:, :, None]).to(torch.float32)
рдЯреНрд░реЗрдирд┐рдВрдЧ
vae = ODEVAE(2, 64, 6) vae = vae.cuda() if use_cuda: vae = vae.cuda() optim = torch.optim.Adam(vae.parameters(), betas=(0.9, 0.999), lr=0.001) preload = False n_epochs = 20000 batch_size = 100 plot_traj_idx = 1 plot_traj = orig_trajs[:, plot_traj_idx:plot_traj_idx+1] plot_obs = samp_trajs[:, plot_traj_idx:plot_traj_idx+1] plot_ts = samp_ts[:, plot_traj_idx:plot_traj_idx+1] if use_cuda: plot_traj = plot_traj.cuda() plot_obs = plot_obs.cuda() plot_ts = plot_ts.cuda() if preload: vae.load_state_dict(torch.load("models/vae_spirals.sd")) for epoch_idx in range(n_epochs): losses = [] train_iter = gen_batch(batch_size) for x, t in train_iter: optim.zero_grad() if use_cuda: x, t = x.cuda(), t.cuda() max_len = np.random.choice([30, 50, 100]) permutation = np.random.permutation(t.shape[0]) np.random.shuffle(permutation) permutation = np.sort(permutation[:max_len]) x, t = x[permutation], t[permutation] x_p, z, z_mean, z_log_var = vae(x, t) z_var = torch.exp(z_log_var) kl_loss = -0.5 * torch.sum(1 + z_log_var - z_mean**2 - z_var, -1) loss = 0.5 * ((x-x_p)**2).sum(-1).sum(0) / noise_std**2 + kl_loss loss = torch.mean(loss) loss /= max_len loss.backward() optim.step() losses.append(loss.item()) print(f"Epoch {epoch_idx}") frm, to, to_seed = 0, 200, 50 seed_trajs = samp_trajs[frm:to_seed] ts = samp_ts[frm:to] if use_cuda: seed_trajs = seed_trajs.cuda() ts = ts.cuda() samp_trajs_p = to_np(vae.generate_with_seed(seed_trajs, ts)) fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(15, 9)) axes = axes.flatten() for i, ax in enumerate(axes): ax.scatter(to_np(seed_trajs[:, i, 0]), to_np(seed_trajs[:, i, 1]), c=to_np(ts[frm:to_seed, i, 0]), cmap=cm.plasma) ax.plot(to_np(orig_trajs[frm:to, i, 0]), to_np(orig_trajs[frm:to, i, 1])) ax.plot(samp_trajs_p[:, i, 0], samp_trajs_p[:, i, 1]) plt.show() print(np.mean(losses), np.median(losses)) clear_output(wait=True) spiral_0_idx = 3 spiral_1_idx = 6 homotopy_p = Tensor(np.linspace(0., 1., 10)[:, None]) vae = vae if use_cuda: homotopy_p = homotopy_p.cuda() vae = vae.cuda() spiral_0 = orig_trajs[:, spiral_0_idx:spiral_0_idx+1, :] spiral_1 = orig_trajs[:, spiral_1_idx:spiral_1_idx+1, :] ts_0 = samp_ts[:, spiral_0_idx:spiral_0_idx+1, :] ts_1 = samp_ts[:, spiral_1_idx:spiral_1_idx+1, :] if use_cuda: spiral_0, ts_0 = spiral_0.cuda(), ts_0.cuda() spiral_1, ts_1 = spiral_1.cuda(), ts_1.cuda() z_cw, _ = vae.encoder(spiral_0, ts_0) z_cc, _ = vae.encoder(spiral_1, ts_1) homotopy_z = z_cw * (1 - homotopy_p) + z_cc * homotopy_p t = torch.from_numpy(np.linspace(0, 6*np.pi, 200)) t = t[:, None].expand(200, 10)[:, :, None].cuda() t = t.cuda() if use_cuda else t hom_gen_trajs = vae.decoder(homotopy_z, t) fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(15, 5)) axes = axes.flatten() for i, ax in enumerate(axes): ax.plot(to_np(hom_gen_trajs[:, i, 0]), to_np(hom_gen_trajs[:, i, 1])) plt.show() torch.save(vae.state_dict(), "models/vae_spirals.sd")
рдПрдХ рд░рд╛рдд рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЗ рдмрд╛рдж рдРрд╕рд╛ рд╣реА рд╣реЛрддрд╛ рд╣реИрдЕрдВрдХ рдореВрд▓ рдкреНрд░рдХреНрд╖реЗрдкрд╡рдХреНрд░ (рдиреАрд▓рд╛) рдХреЗ рд╢реЛрд░ рдЕрд╡рд▓реЛрдХрди рд╣реИрдВ,
рдкреАрд▓реЗ рдЗрдирдкреБрдЯ рдХреЗ рд░реВрдк рдореЗрдВ рдмрд┐рдВрджреБрдУрдВ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рддреЗ рд╣реБрдП, рдкреБрдирд░реНрдирд┐рд░реНрдорд┐рдд рдФрд░ рдкреНрд░рдХреНрд╖реЗрдкрд┐рдд рдкреНрд░рдХреНрд╖реЗрдкрд╡рдХреНрд░ рд╣реИрдВред
рдбреЙрдЯ рдХрд╛ рд░рдВрдЧ рд╕рдордп рджрд┐рдЦрд╛рддрд╛ рд╣реИредрдХреБрдЫ рдЙрджрд╛рд╣рд░рдгреЛрдВ рдХреЗ рдкреБрдирд░реНрдирд┐рд░реНрдорд╛рдг рдмрд╣реБрдд рдЕрдЪреНрдЫреЗ рдирд╣реАрдВ рд▓рдЧрддреЗ рд╣реИрдВред рд╣реЛ рд╕рдХрддрд╛ рд╣реИ рдХрд┐ рдореЙрдбрд▓ рдкрд░реНрдпрд╛рдкреНрдд рдЬрдЯрд┐рд▓ рдирд╣реАрдВ рд╣реИ рдпрд╛ рд▓рдВрдмреЗ рд╕рдордп рддрдХ рдЕрдзреНрдпрдпрди рдирд╣реАрдВ рдХрд┐рдпрд╛ рдЧрдпрд╛ рд╣реИред рдХрд┐рд╕реА рднреА рдорд╛рдорд▓реЗ рдореЗрдВ, рдкреБрдирд░реНрдирд┐рд░реНрдорд╛рдг рдмрд╣реБрдд рд╣реА рдЙрдЪрд┐рдд рджрд┐рдЦрддрд╛ рд╣реИредрдЕрдм рджреЗрдЦрддреЗ рд╣реИрдВ рдХрд┐ рдЕрдЧрд░ рд╣рдо рдПрдХ рдХреНрд▓реЙрдХрд╡рд╛рдЗрдЬ рдкреНрд░рдХреНрд╖реЗрдкрд╡рдХреНрд░ рдореЗрдВ рдПрдХ рдЫрд┐рдкреЗ рд╣реБрдП рдЪрд░ рдХреЛ рдПрдХ рдПрдВрдЯреА-рдХреНрд▓реЙрдХрд╡рд╛рдЗрдЬ рдкреНрд░рдХреНрд╖реЗрдкрд╡рдХреНрд░ рдореЗрдВ рдкреНрд░рдХреНрд╖реЗрдкрд┐рдд рдХрд░рддреЗ рд╣реИрдВ, рддреЛ рдХреНрдпрд╛ рд╣реЛрддрд╛ рд╣реИредрд▓реЗрдЦрдХ рдиреНрдпреВрд░рд▓ ODE рдФрд░ рдПрдХ рд╕рд░рд▓ рдкреБрдирд░рд╛рд╡рд░реНрддреА рдиреЗрдЯрд╡рд░реНрдХ рдХреЗ рдмреАрдЪ рдкреБрдирд░реНрдирд┐рд░реНрдорд╛рдг рдФрд░ рдкрде рдкреНрд░рдХреНрд╖реЗрдк рдХреА рддреБрд▓рдирд╛ рднреА рдХрд░рддреЗ рд╣реИрдВ редрдореВрд▓ рд▓реЗрдЦ рд╕реЗ рдЪрд┐рддреНрд░рдгредрд▓рдЧрд╛рддрд╛рд░ рд╕рд╛рдорд╛рдиреНрдп рд╣реЛрдиреЗ рд╡рд╛рд▓реА рдзрд╛рд░рд╛рдПрдБ
. , , (, ), .
, , .
,
.
, ,
, , , .
:
( ) ( ) ;
-X ┬л┬╗ ( ) ┬л┬╗ ( ).
bekemax .
Neural ODEs . !