рдорд┐рд╢реНрд░рдг рдШрдирддреНрд╡ рдиреЗрдЯрд╡рд░реНрдХ


рд╕рднреА рдХреЛ рдирдорд╕реНрдХрд╛рд░!

рдХреЗ рд░реВрдк рдореЗрдВ рдЖрдк рдЕрдиреБрдорд╛рди рд▓рдЧрд╛рдпрд╛ рд╣реИ, рддрдВрддреНрд░рд┐рдХрд╛ рдиреЗрдЯрд╡рд░реНрдХ рдФрд░ рдорд╢реАрди рд╕реАрдЦрдиреЗ рдХреЗ рдмрд╛рд░реЗ рдореЗрдВ рдмрд╛рдд рдХрд░рддреЗ рд╣реИрдВред рдирд╛рдо рд╕реЗ рдпрд╣ рд╕реНрдкрд╖реНрдЯ рд╣реИ рдХрд┐ рдорд┐рдХреНрд╕рдЪрд░ рдбреЗрдВрд╕рд┐рдЯреА рдиреЗрдЯрд╡рд░реНрдХреНрд╕ рдХреЗ рдмрд╛рд░реЗ рдореЗрдВ рдХреНрдпрд╛ рдмрддрд╛рдпрд╛ рдЬрд╛рдПрдЧрд╛, рдлрд┐рд░ рд╕рд┐рд░реНрдл рдПрдордбреАрдПрди, рдореИрдВ рдирд╛рдо рдХрд╛ рдЕрдиреБрд╡рд╛рдж рдирд╣реАрдВ рдХрд░рдирд╛ рдЪрд╛рд╣рддрд╛ рдФрд░ рдЗрд╕реЗ рдЫреЛрдбрд╝ рджреЗрдирд╛ рдЪрд╛рд╣рддрд╛ рд╣реВрдВред рд╣рд╛рдВ, рд╣рд╛рдВ, рд╣рд╛рдВ ... рдереЛрдбрд╝рд╛ рдЙрдмрд╛рдК рдЧрдгрд┐рдд рдФрд░ рд╕рдВрднрд╛рд╡рдирд╛ рд╕рд┐рджреНрдзрд╛рдВрдд рд╣реЛрдЧрд╛, рд▓реЗрдХрд┐рди рдЗрд╕рдХреЗ рдмрд┐рдирд╛, рджреБрд░реНрднрд╛рдЧреНрдп рд╕реЗ, рдпрд╛ рд╕реМрднрд╛рдЧреНрдп рд╕реЗ, рдпрд╣ рдЖрдкрдХреЛ рддрдп рдХрд░рдирд╛ рд╣реИ рдХрд┐ рдХреНрдпрд╛ рдорд╢реАрди рд╕реАрдЦрдиреЗ рдХреА рджреБрдирд┐рдпрд╛ рдХреА рдХрд▓реНрдкрдирд╛ рдХрд░рдирд╛ рдореБрд╢реНрдХрд┐рд▓ рд╣реИред рд▓реЗрдХрд┐рди рдореИрдВ рдЖрдкрдХреЛ рдЖрд╢реНрд╡рд╕реНрдд рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП рдЬрд▓реНрджрдмрд╛рдЬреА рдХрд░рддрд╛ рд╣реВрдВ, рдпрд╣ рдЕрдкреЗрдХреНрд╖рд╛рдХреГрдд рдЫреЛрдЯрд╛ рд╣реЛрдЧрд╛ рдФрд░ рдпрд╣ рдмрд╣реБрдд рдореБрд╢реНрдХрд┐рд▓ рдирд╣реАрдВ рд╣реЛрдЧрд╛ред рд╡реИрд╕реЗ рднреА, рдЖрдк рдЗрд╕реЗ рдЫреЛрдбрд╝ рд╕рдХрддреЗ рд╣реИрдВ, рд▓реЗрдХрд┐рди рдмрд╕ рдкрд╛рдпрдерди рдФрд░ рдкреНрдпреЛрд░реЛрдЪ рдореЗрдВ рдереЛрдбрд╝реА рдорд╛рддреНрд░рд╛ рдореЗрдВ рдХреЛрдб рдХреЛ рджреЗрдЦреЗрдВ, рдпрд╣ рд╕рд╣реА рд╣реИ, рд╣рдо PyTorch рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ рдиреЗрдЯрд╡рд░реНрдХ, рд╕рд╛рде рд╣реА рдкрд░рд┐рдгрд╛рдореЛрдВ рдХреЗ рд╕рд╛рде рд╡рд┐рднрд┐рдиреНрди рдЧреНрд░рд╛рдлрд╝ рд▓рд┐рдЦреЗрдВрдЧреЗред рд▓реЗрдХрд┐рди рд╕рдмрд╕реЗ рдорд╣рддреНрд╡рдкреВрд░реНрдг рдмрд╛рдд рдпрд╣ рд╣реИ рдХрд┐ рдПрдордбреА рдиреЗрдЯрд╡рд░реНрдХ рдХреНрдпрд╛ рд╣реИрдВ, рдЗрд╕реЗ рдереЛрдбрд╝рд╛ рд╕рдордЭрдиреЗ рдФрд░ рд╕рдордЭрдиреЗ рдХрд╛ рдЕрд╡рд╕рд░ рдорд┐рд▓реЗрдЧрд╛ред

рдЦреИрд░, рдЪрд▓реЛ рд╢реБрд░реВ рд╣реЛ рдЬрд╛рдУ!



рд╡рд╛рдкрд╕реА


рд╢реБрд░реВ рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП, рдЖрдЗрдП рдЕрдкрдиреЗ рдЬреНрдЮрд╛рди рдХреЛ рдереЛрдбрд╝рд╛ рддрд╛рдЬрд╝рд╛ рдХрд░реЗрдВ рдФрд░ рдпрд╛рдж рдХрд░реЗрдВ, рдХрд╛рдлреА рд╕рдВрдХреНрд╖реЗрдк рдореЗрдВ, рд░реИрдЦрд┐рдХ рдкреНрд░рддрд┐рдЧрдорди рдХреНрдпрд╛ рд╣реИ ред

рд╣рдорд╛рд░реЗ рдкрд╛рд╕ рдПрдХ рд╡реЗрдХреНрдЯрд░ рд╣реИ X = \ {x_1, x_2, ..., x_n \}X = \ {x_1, x_2, ..., x_n \} рд╣рдореЗрдВ рдореВрд▓реНрдп рдХрд╛ рдЕрдиреБрдорд╛рди рд▓рдЧрд╛рдиреЗ рдХреА рдЖрд╡рд╢реНрдпрдХрддрд╛ рд╣реИ рдпрдп , рдЬреЛ рдХрд┐рд╕реА рддрд░рд╣ рдирд┐рд░реНрднрд░ рдХрд░рддрд╛ рд╣реИ X рдХреБрдЫ рд░реИрдЦрд┐рдХ рдореЙрдбрд▓ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдирд╛:

 hatY=XT hat beta

рддреНрд░реБрдЯрд┐ рдлрд╝рдВрдХреНрд╢рди рдХреЗ рд░реВрдк рдореЗрдВ, рд╣рдо рдЪреБрдХрддрд╛ рддреНрд░реБрдЯрд┐ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░реЗрдВрдЧреЗ:

SE (\ Beta) = \ sum_ {i = 1} ^ n (y_i- \ hat {y} _i) ^ 2 = \ sum_ {i = 1} ^ N (y_i-x_i ^ T \ hat {рдмреАрдЯрд╛ рдмреАрдЯрд╛}} ) ^ 2

рдПрд╕рдИ рдХреЗ рд╡реНрдпреБрддреНрдкрдиреНрди рд▓реЗрдиреЗ рдФрд░ рд╢реВрдиреНрдп рдХреЗ рд▓рд┐рдП рдЗрд╕рдХрд╛ рдореВрд▓реНрдп рдирд┐рд░реНрдзрд╛рд░рд┐рдд рдХрд░рдХреЗ рдЗрд╕ рд╕рдорд╕реНрдпрд╛ рдХреЛ рд╕реАрдзреЗ рд╣рд▓ рдХрд┐рдпрд╛ рдЬрд╛ рд╕рдХрддрд╛ рд╣реИ:

 frac deltaSE( Beta) delta beta=2XT( mathbfyтИТX beta)=0

рдЗрд╕ рдкреНрд░рдХрд╛рд░, рд╣рдо рдмрд╕ рдЗрд╕рдХрд╛ рдиреНрдпреВрдирддрдо рдкрд╛рддреЗ рд╣реИрдВ, рдФрд░ рдПрд╕рдИ рдПрдХ рджреНрд╡рд┐рдШрд╛рдд рдлрд╝рдВрдХреНрд╢рди рд╣реИ, рдЬрд┐рд╕рдХрд╛ рдЕрд░реНрде рд╣реИ рдХрд┐ рдиреНрдпреВрдирддрдо рд╣рдореЗрд╢рд╛ рдореМрдЬреВрдж рд░рд╣реЗрдЧрд╛ред рдЙрд╕рдХреЗ рдмрд╛рдж, рдЖрдк рдкрд╣рд▓реЗ рд╕реЗ рд╣реА рдЖрд╕рд╛рдиреА рд╕реЗ рдкрд╛ рд╕рдХрддреЗ рд╣реИрдВ  рдмреАрдЯрд╛ :

 hat Beta=(XTX)тИТ1XT mathbfy

рдмрд╕, рд╕рдорд╕реНрдпрд╛ рд╣рд▓ рд╣реЛ рдЧрдИред рдпрд╣ рд╡рд╣ рдЬрдЧрд╣ рд╣реИ рдЬрд╣рд╛рдВ рд╣рдо рдпрд╛рдж рдХрд░рддреЗ рд╣реИрдВ рдХрд┐ рд░реИрдЦрд┐рдХ рдкреНрд░рддрд┐рдЧрдорди рдХреНрдпрд╛ рд╣реИред

рдмреЗрд╢рдХ, рдбреЗрдЯрд╛ рдкреАрдврд╝реА рдХреА рдкреНрд░рдХреГрддрд┐ рдореЗрдВ рдирд┐рд╣рд┐рдд рдирд┐рд░реНрднрд░рддрд╛ рдЕрд▓рдЧ-рдЕрд▓рдЧ рд╣реЛ рд╕рдХрддреА рд╣реИ рдФрд░ рдлрд┐рд░ рд╣рдорд╛рд░реЗ рдореЙрдбрд▓ рдореЗрдВ рдкрд╣рд▓реЗ рд╕реЗ рд╣реА рдХреБрдЫ рдЧреИрд░-рд╢реБрджреНрдзрддрд╛ рдХреЛ рдЬреЛрдбрд╝рд╛ рдЬрд╛рдирд╛ рдЪрд╛рд╣рд┐рдПред рдореИрдЯреНрд░рд┐рдХреНрд╕ рдФрд░ рдмрдбрд╝реЗ рдбреЗрдЯрд╛ рдХреЗ рд▓рд┐рдП рдкреНрд░рддрд┐рдЧрдорди рд╕рдорд╕реНрдпрд╛ рдХреЛ рд╕реАрдзреЗ рд╣рд▓ рдХрд░рдирд╛ рднреА рдПрдХ рдмреБрд░рд╛ рд╡рд┐рдЪрд╛рд░ рд╣реИ, рдХреНрдпреЛрдВрдХрд┐ рдореИрдЯреНрд░рд┐рдХреНрд╕ рд╣реИ XTX рдЖрдпрд╛рдореА рд╕реНрд╡рд░реВрдк n nn , рдФрд░ рдПрдХ рдХреЛ рдЕрднреА рднреА рдЕрдкрдиреЗ рд╡реНрдпреБрддреНрдХреНрд░рдо рдореИрдЯреНрд░рд┐рдХреНрд╕ рдХреЛ рдЦреЛрдЬрдиреЗ рдХреА рдЖрд╡рд╢реНрдпрдХрддрд╛ рд╣реИ, рдФрд░ рдпрд╣ рдЕрдХреНрд╕рд░ рдРрд╕рд╛ рд╣реЛрддрд╛ рд╣реИ рдХрд┐ рдРрд╕рд╛ рдореИрдЯреНрд░рд┐рдХреНрд╕ рдмрд╕ рдореМрдЬреВрдж рдирд╣реАрдВ рд╣реЛрддрд╛ рд╣реИред рдЗрд╕ рдорд╛рдорд▓реЗ рдореЗрдВ, рдврд╛рд▓ рд╡рдВрд╢ рдкрд░ рдЖрдзрд╛рд░рд┐рдд рд╡рд┐рднрд┐рдиреНрди рд╡рд┐рдзрд┐рдпрд╛рдВ рд╣рдорд╛рд░реА рд╕рд╣рд╛рдпрддрд╛ рдХреЗ рд▓рд┐рдП рдЖрддреА рд╣реИрдВред рдореЙрдбрд▓ рдХреЗ рдЧреИрд░-рд░реИрдЦрд┐рдХрддрд╛ рдХреЛ рд╡рд┐рднрд┐рдиреНрди рддрд░реАрдХреЛрдВ рд╕реЗ рд▓рд╛рдЧреВ рдХрд┐рдпрд╛ рдЬрд╛ рд╕рдХрддрд╛ рд╣реИ, рдЬрд┐рд╕рдореЗрдВ рддрдВрддреНрд░рд┐рдХрд╛ рдиреЗрдЯрд╡рд░реНрдХ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдирд╛ рд╢рд╛рдорд┐рд▓ рд╣реИред

рд▓реЗрдХрд┐рди рдЕрдм, рдЗрд╕ рдмрд╛рд░реЗ рдореЗрдВ рдмрд╛рдд рдХрд░рддреЗ рд╣реИрдВ, рд▓реЗрдХрд┐рди рддреНрд░реБрдЯрд┐ рдХрд╛рд░реНрдпреЛрдВ рдХреЗ рдмрд╛рд░реЗ рдореЗрдВред рдЬрдм рдбреЗрдЯрд╛ рдХрд╛ рдЧреИрд░-рд░реИрдЦрд┐рдХ рд╕рдВрдмрдВрдз рд╣реЛ рд╕рдХрддрд╛ рд╣реИ, рддреЛ рдПрд╕рдИ рдФрд░ рд▓реЙрдЧ-рд▓рд┐рдХреНрд▓рд┐рд╣реБрдб рдХреЗ рдмреАрдЪ рдХреНрдпрд╛ рдЕрдВрддрд░ рд╣реИ?

рд╣рдо рдЪрд┐рдбрд╝рд┐рдпрд╛рдШрд░ рд╕реЗ рд╕рдВрдмрдВрдзрд┐рдд рд╣реИрдВ, рдЕрд░реНрдерд╛рддреН: рдУрдПрд▓рдПрд╕, рдПрд▓рдПрд╕, рдПрд╕рдИ, рдПрдордПрд╕рдИ, рдЖрд░рдПрд╕рдПрд╕
рдпрд╣ рд╕рдм рдПрдХ рдФрд░ рд╕рдорд╛рди рд░реВрдк рд╕реЗ рд╕рдорд╛рди рд╣реИ, рдЖрд░рдПрд╕рдПрд╕ - рд╡рд░реНрдЧреЛрдВ рдХрд╛ рдЕрд╡рд╢рд┐рд╖реНрдЯ рдпреЛрдЧ, рдУрдПрд▓рдПрд╕ - рд╕рд╛рдзрд╛рд░рдг рдиреНрдпреВрдирддрдо рд╡рд░реНрдЧ, рдПрд▓рдПрд╕ - рдХрдо рд╕реЗ рдХрдо рд╡рд░реНрдЧ, рдПрдордПрд╕рдИ - рдорддрд▓рдм рдЪреБрдХрддрд╛ рддреНрд░реБрдЯрд┐, рдПрд╕рдИ - рдЪреБрдХрддрд╛ рддреНрд░реБрдЯрд┐ред рд╡рд┐рднрд┐рдиреНрди рд╕реНрд░реЛрддреЛрдВ рдореЗрдВ рдЖрдк рдЕрд▓рдЧ-рдЕрд▓рдЧ рдирд╛рдо рдкрд╛ рд╕рдХрддреЗ рд╣реИрдВред рдЗрд╕рдХрд╛ рд╕рд╛рд░ рдХреЗрд╡рд▓ рдПрдХ рд╣реИ: рджреНрд╡рд┐рдШрд╛рдд рд╡рд┐рдЪрд▓рди ред рдЖрдк рдирд┐рд╢реНрдЪрд┐рдд рд░реВрдк рд╕реЗ рднреНрд░рдорд┐рдд рд╣реЛ рд╕рдХрддреЗ рд╣реИрдВ, рд▓реЗрдХрд┐рди рдЖрдкрдХреЛ рдЗрд╕рдХреА рдЖрджрдд рд╣реИред

рдпрд╣ рдзреНрдпрд╛рди рджреЗрдиреЗ рдпреЛрдЧреНрдп рд╣реИ рдХрд┐ рдПрдордПрд╕рдИ рдорд╛рдирдХ рд╡рд┐рдЪрд▓рди рд╣реИ, рд╕рдВрдкреВрд░реНрдг рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдбреЗрдЯрд╛ рд╕реЗрдЯ рдХреЗ рд▓рд┐рдП рддреНрд░реБрдЯрд┐ рдХрд╛ рдПрдХ рдирд┐рд╢реНрдЪрд┐рдд рдФрд╕рдд рдореВрд▓реНрдп рд╣реИред рд╡реНрдпрд╡рд╣рд╛рд░ рдореЗрдВ, рдПрдордПрд╕рдИ рдХрд╛ рдЙрдкрдпреЛрдЧ рдЖрдорддреМрд░ рдкрд░ рдХрд┐рдпрд╛ рдЬрд╛рддрд╛ рд╣реИред рд╕реВрддреНрд░ рд╡рд┐рд╢реЗрд╖ рд░реВрдк рд╕реЗ рднрд┐рдиреНрди рдирд╣реАрдВ рд╣реИ:

MSE( Beta)= frac1N sumni=1(yiтИТ hatyi)2

рдПрди - рдбреЗрдЯрд╛рд╕реЗрдЯ рдХрд╛ рдЖрдХрд╛рд░, \ _ {y} _i - рдХреЗ рд▓рд┐рдП рдореЙрдбрд▓ рднрд╡рд┐рд╖реНрдпрд╡рд╛рдгреА yi ред

рдЗрд╕реЗ рд░реЛрдХреЛ! рд╕рдВрднрд╛рд╡рдирд╛? рдпрд╣ рдкреНрд░рд╛рдпрд┐рдХрддрд╛ рд╕рд┐рджреНрдзрд╛рдВрдд рд╕реЗ рдХреБрдЫ рд╣реИред рдпрд╣ рд╕рд╣реА рд╣реИ - рдпрд╣ рд╢реБрджреНрдз рд╕рдВрднрд╛рд╡рдирд╛ рд╕рд┐рджреНрдзрд╛рдВрдд рд╣реИред рд▓реЗрдХрд┐рди рджреНрд╡рд┐рдШрд╛рдд рд╡рд┐рдЪрд▓рди рдХреИрд╕реЗ рд╕рдВрднрд╛рд╡рд┐рдд рдлрд╝рдВрдХреНрд╢рди рд╕реЗ рд╕рдВрдмрдВрдзрд┐рдд рд╣реЛ рд╕рдХрддрд╛ рд╣реИ? рдФрд░ рдпрд╣ рдХреИрд╕реЗ рдирд┐рдХрд▓рд╛ред рдпрд╣ рдЕрдзрд┐рдХрддрдо рд╕рдВрднрд╛рд╡рдирд╛ рдЦреЛрдЬрдиреЗ рдХреЗ рд╕рд╛рде рдЬреБрдбрд╝рд╛ рд╣реБрдЖ рд╣реИ (рдЕрдзрд┐рдХрддрдо рд╕рдВрднрд╛рд╡рдирд╛) рдФрд░ рд╕рд╛рдорд╛рдиреНрдп рд╡рд┐рддрд░рдг рдХреЗ рд╕рд╛рде, рдЕрдзрд┐рдХ рд╕рдЯреАрдХ рд╣реЛрдиреЗ рдХреЗ рд▓рд┐рдП, рдЗрд╕рдХреЗ рдФрд╕рдд рдХреЗ рд╕рд╛рде  рдореБ ред

рдпрд╣ рдорд╣рд╕реВрд╕ рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП рдХрд┐ рдпрд╣ рдРрд╕рд╛ рд╣реИ, рдЪрд▓реЛ рдлрд┐рд░ рд╕реЗ рд╡рд░реНрдЧ-рд╡рд┐рдЪрд▓рди рдлрд╝рдВрдХреНрд╢рди рдХреЛ рджреЗрдЦреЗрдВ:

RSS( Beta)= sumni=1(yiтИТ hatyi)2 qquad qquad(1)

рдЕрдм рдорд╛рди рд▓реЗрдВ рдХрд┐ рд╕рдВрднрд╛рд╡рдирд╛ рдлрд╝рдВрдХреНрд╢рди рдХрд╛ рдПрдХ рд╕рд╛рдорд╛рдиреНрдп рд░реВрдк рд╣реИ, рдЕрд░реНрдерд╛рддреН рдПрдХ рдЧрд╛рдКрд╕реА рдпрд╛ рд╕рд╛рдорд╛рдиреНрдп рд╡рд┐рддрд░рдг:

L(X)=p(X(ред|Theta)= prodX mathcalN(xi; mu, sigma2)

рд╕рд╛рдорд╛рдиреНрдп рддреМрд░ рдкрд░, рд╕рдВрднрд╛рд╡рдирд╛ рд╕рдорд╛рд░реЛрд╣ рдХреНрдпрд╛ рд╣реИ рдФрд░ рдЗрд╕рдХрд╛ рдХреНрдпрд╛ рдЕрд░реНрде рд╣реИ рдпрд╣ рдореИрдВ рдирд╣реАрдВ рдмрддрд╛рдКрдВрдЧрд╛, рдЖрдк рдЗрд╕рдХреЗ рдмрд╛рд░реЗ рдореЗрдВ рдХрд╣реАрдВ рдФрд░ рдкрдврд╝ рд╕рдХрддреЗ рд╣реИрдВ, рдЖрдкрдХреЛ рдЧрд╣рди рд╕рдордЭ рдХреЗ рд▓рд┐рдП рд╕рд╢рд░реНрдд рд╕рдВрднрд╛рд╡реНрдпрддрд╛, рдмреЗрдпрд╕ рдкреНрд░рдореЗрдп рдФрд░ рдмрд╣реБрдд рдХреБрдЫ рдХреА рдЕрд╡рдзрд╛рд░рдгрд╛ рд╕реЗ рднреА рдкрд░рд┐рдЪрд┐рдд рд╣реЛрдирд╛ рдЪрд╛рд╣рд┐рдПред рдпрд╣ рд╕рдм рд╕рдВрднрд╛рд╡реНрдпрддрд╛ рдХреЗ рд╢реБрджреНрдз рд╕рд┐рджреНрдзрд╛рдВрдд рдореЗрдВ рдЬрд╛рддрд╛ рд╣реИ, рдЬрд┐рд╕рдХрд╛ рдЕрдзреНрдпрдпрди рд╕реНрдХреВрд▓ рдФрд░ рд╡рд┐рд╢реНрд╡рд╡рд┐рджреНрдпрд╛рд▓рдп рджреЛрдиреЛрдВ рдореЗрдВ рдХрд┐рдпрд╛ рдЬрд╛рддрд╛ рд╣реИред

рдЕрдм, рд╕рд╛рдорд╛рдиреНрдп рд╡рд┐рддрд░рдг рд╕реВрддреНрд░ рдХреЛ рдпрд╛рдж рдХрд░рддреЗ рд╣реБрдП, рд╣рдо рдкреНрд░рд╛рдкреНрдд рдХрд░рддреЗ рд╣реИрдВ:

L(X; mu, sigma2)= prodX frac1 sqrt2 pi sigma2eтИТ frac(xiтИТmu))22 рд╕рд┐рдЧреНрдорд╛2 qquad qquad(2)

рдХреНрдпрд╛ рд╣реЛрдЧрд╛ рдЕрдЧрд░ рд╣рдо рдорд╛рдирдХ рд╡рд┐рдЪрд▓рди рдбрд╛рд▓рддреЗ рд╣реИрдВ  рд╕рд┐рдЧреНрдорд╛2=1 рдФрд░ рд╕рднреА рд╕реНрдерд┐рд░рд╛рдВрдХ рдХреЛ рд╕реВрддреНрд░ рдореЗрдВ рд╣рдЯрд╛ рджреЗрдВ (2), рдмрд╕ рд╣рдЯрд╛ рджреЗрдВ, рдХрдо рди рдХрд░реЗрдВ, рдХреНрдпреЛрдВрдХрд┐ рдлрд╝рдВрдХреНрд╢рди рдХрд╛ рдиреНрдпреВрдирддрдо рдкрддрд╛ рд▓рдЧрд╛рдирд╛ рдЙрди рдкрд░ рдирд┐рд░реНрднрд░ рдирд╣реАрдВ рдХрд░рддрд╛ рд╣реИред рддрдм рд╣рдо рдЗрд╕реЗ рджреЗрдЦреЗрдВрдЧреЗ:

L(X; mu, sigma2) sim prodXeтИТ(xiтИТ mu)2

рдЕрднреА рднреА рдХреБрдЫ рдкрд╕рдВрдж рдирд╣реАрдВ рд╣реИ? рдирд╣реАрдВ? рдареАрдХ рд╣реИ, рдХреНрдпрд╛ рд╣реЛрдЧрд╛ рдЕрдЧрд░ рд╣рдо рдлрд╝рдВрдХреНрд╢рди рдХрд╛ рд▓рдШреБрдЧрдгрдХ рд▓реЗрддреЗ рд╣реИрдВ? рд▓рдШреБрдЧрдгрдХ рд╕реЗ, рд╕рд╛рдорд╛рдиреНрдп рддреМрд░ рдкрд░, рдХреБрдЫ рдкреНрд▓рд╕ рд╣реЛрддреЗ рд╣реИрдВ: рдЧреБрдгрд╛ рдПрдХ рдпреЛрдЧ рдореЗрдВ рдмрджрд▓ рдЬрд╛рдПрдЧрд╛, рдЧреБрдгрд╛ рдореЗрдВ рдПрдХ рдбрд┐рдЧреНрд░реА, рдФрд░  loge=1 - рдЗрд╕ рд╕рдВрдкрддреНрддрд┐ рдХреЗ рд▓рд┐рдП рдпрд╣ рд╕реНрдкрд╖реНрдЯ рдХрд░рдиреЗ рдпреЛрдЧреНрдп рд╣реИ рдХрд┐ рд╣рдо рдкреНрд░рд╛рдХреГрддрд┐рдХ рд▓рдШреБрдЧрдгрдХ рдХреЗ рдмрд╛рд░реЗ рдореЗрдВ рдмрд╛рдд рдХрд░ рд░рд╣реЗ рд╣реИрдВ рдФрд░, рд╕рдЦреНрддреА рд╕реЗ рдмреЛрд▓ рд░рд╣реЗ рд╣реИрдВ  lne=1 ред рдФрд░ рд╕рд╛рдорд╛рдиреНрдп рддреМрд░ рдкрд░, рдПрдХ рдлрд╝рдВрдХреНрд╢рди рдХрд╛ рд▓рдШреБрдЧрдгрдХ рдЗрд╕рдХреА рдЕрдзрд┐рдХрддрдо рдкрд░рд┐рд╡рд░реНрддрди рдирд╣реАрдВ рдХрд░рддрд╛ рд╣реИ, рдФрд░ рдпрд╣ рд╣рдорд╛рд░реЗ рд▓рд┐рдП рд╕рдмрд╕реЗ рдорд╣рддреНрд╡рдкреВрд░реНрдг рд╡рд┐рд╢реЗрд╖рддрд╛ рд╣реИред рд▓реЙрдЧ-рд▓рд┐рдХреЗрд▓рд┐рд╣реБрдб рдФрд░ рд▓рд┐рдХреЗрд▓рд┐рд╣реБрдб рдХреЗ рд╕рд╛рде рд╕рдВрдмрдВрдз рдФрд░ рдпрд╣ рдХреНрдпреЛрдВ рдЙрдкрдпреЛрдЧреА рд╣реЛрдЧрд╛ рдПрдХ рдЫреЛрдЯреЗ рд╕реЗ рд╡рд┐рд╖рдпрд╛рдВрддрд░ рдореЗрдВ рдиреАрдЪреЗ рд╡рд░реНрдгрд┐рдд рдХрд┐рдпрд╛ рдЬрд╛рдПрдЧрд╛ред рдФрд░ рдЗрд╕рд▓рд┐рдП рд╣рдордиреЗ рдХреНрдпрд╛ рдХрд┐рдпрд╛: рд╕рднреА рд╕реНрдерд┐рд░рд╛рдВрдХ рд╣рдЯрд╛ рджрд┐рдП, рдФрд░ рд╕рдВрднрд╛рд╡рдирд╛ рд╕рдорд╛рд░реЛрд╣ рдХрд╛ рд▓рдШреБрдЧрдгрдХ рд▓рд┐рдпрд╛ред рдЙрдиреНрд╣реЛрдВрдиреЗ рдорд╛рдЗрдирд╕ рд╕рд╛рдЗрди рдХреЛ рднреА рд╣рдЯрд╛ рджрд┐рдпрд╛, рдЗрд╕ рдкреНрд░рдХрд╛рд░ рд▓реЙрдЧ-рд▓рд┐рдХреЗрд▓рд┐рд╣реБрдб рдХреЛ рдиреЗрдЧреЗрдЯрд┐рд╡ рд▓реЙрдЧ-рд▓рд┐рдХреЗрд▓рд┐рд╣реБрдб (рдПрдирдПрд▓рдПрд▓) рдореЗрдВ рдмрджрд▓ рджрд┐рдпрд╛, рдЙрдирдХреЗ рдмреАрдЪ рдХреЗ рд╕рдВрдмрдВрдз рдХреЛ рднреА рдмреЛрдирд╕ рдХреЗ рд░реВрдк рдореЗрдВ рд╡рд░реНрдгрд┐рдд рдХрд┐рдпрд╛ рдЬрд╛рдПрдЧрд╛ред рдирддреАрдЬрддрди, рд╣рдореЗрдВ рдПрдирдПрд▓рдПрд▓ рдлрд╝рдВрдХреНрд╢рди рдорд┐рд▓рд╛:

 logL(X; mu,I2) sim sum(XтИТ mu)2

RSS рдлрд╝рдВрдХреНрд╢рди (1) рдкрд░ рдПрдХ рдФрд░ рдирдЬрд╝рд░ рдбрд╛рд▓реЗрдВред рд╣рд╛рдБ, рд╡реЗ рд╡рд╣реА рд╣реИрдВ! рдмрд┐рд▓рдХреБрд▓ рд╕рд╣реА! рдпрд╣ рднреА рджреЗрдЦрд╛ рдЬрд╛рддрд╛ рд╣реИ рдХрд┐  рдореБ= рдЯреЛрдкреАрдп ред

рдпрджрд┐ рдЖрдк MSE рдорд╛рдирдХ рд╡рд┐рдЪрд▓рди рдлрд╝рдВрдХреНрд╢рди рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рддреЗ рд╣реИрдВ, рддреЛ рд╣рдо рдЗрд╕рд╕реЗ рдкреНрд░рд╛рдкреНрдд рдХрд░рддреЗ рд╣реИрдВ:

 operatornameargminMSE( beta) sim operatornameargmax mathbbEX simPdata logPрдореЙрдбрд▓(x; beta)

рдЬрд╣рд╛рдБ  mathbbE - рдЧрдгрд┐рддреАрдп рдЕрдкреЗрдХреНрд╖рд╛  рдмреАрдЯрд╛ - рдореЙрдбрд▓ рдкреИрд░рд╛рдореАрдЯрд░, рднрд╡рд┐рд╖реНрдп рдореЗрдВ рд╣рдо рдЙрдиреНрд╣реЗрдВ рдирд┐рдореНрди рд░реВрдк рдореЗрдВ рдирд┐рд░реВрдкрд┐рдд рдХрд░реЗрдВрдЧреЗ: $рдереАрдЯрд╛ ред

рдирд┐рд╖реНрдХрд░реНрд╖: рдпрджрд┐ рд╣рдо рдкреНрд░рддрд┐рдЧрдорди рдкреНрд░рд╢реНрди рдореЗрдВ рдПрд▓рдПрд╕ рдкрд░рд┐рд╡рд╛рд░ рдХреЛ рддреНрд░реБрдЯрд┐ рдХрд╛рд░реНрдпреЛрдВ рдХреЗ рд░реВрдк рдореЗрдВ рдЙрдкрдпреЛрдЧ рдХрд░рддреЗ рд╣реИрдВ, рддреЛ рд╣рдо рдЕрдирд┐рд╡рд╛рд░реНрдп рд░реВрдк рд╕реЗ рдЗрд╕ рдорд╛рдорд▓реЗ рдореЗрдВ рдЕрдзрд┐рдХрддрдо рд╕рдВрднрд╛рд╡рдирд╛ рдлрд╝рдВрдХреНрд╢рди рдХреЛ рдЦреЛрдЬрдиреЗ рдХреА рд╕рдорд╕реНрдпрд╛ рдХреЛ рд╣рд▓ рдХрд░рддреЗ рд╣реИрдВ рдЬрдм рд╡рд┐рддрд░рдг рдЧреЙрд╕рд┐рдпрди рд╣реЛрддрд╛ рд╣реИред рдФрд░ рдЕрдиреБрдорд╛рдирд┐рдд рдореВрд▓реНрдп  y рд╕рд╛рдорд╛рдиреНрдп рд╡рд┐рддрд░рдг рдореЗрдВ рдФрд╕рдд рдХреЗ рдмрд░рд╛рдмрд░ред рдФрд░ рдЕрдм рд╣рдо рдЬрд╛рдирддреЗ рд╣реИрдВ рдХрд┐ рдпрд╣ рд╕рдм рдХреИрд╕реЗ рдЬреБрдбрд╝рд╛ рд╣реБрдЖ рд╣реИ, рдХреИрд╕реЗ рд╕рдВрднрд╛рд╡реНрдпрддрд╛ рд╕рд┐рджреНрдзрд╛рдВрдд (рдЗрд╕рдХреА рд╕рдВрднрд╛рд╡рдирд╛ рд╕рдорд╛рд░реЛрд╣ рдФрд░ рд╕рд╛рдорд╛рдиреНрдп рд╡рд┐рддрд░рдг рдХреЗ рд╕рд╛рде) рдФрд░ рдорд╛рдирдХ рд╡рд┐рдЪрд▓рди рдпрд╛ рдУрдПрд▓рдПрд╕ рдХреЗ рддрд░реАрдХреЗ рдЬреБрдбрд╝реЗ рд╣реБрдП рд╣реИрдВред рдЗрд╕рдХреЗ рдмрд╛рд░реЗ рдореЗрдВ рдЕрдзрд┐рдХ рдЬрд╛рдирдХрд╛рд░реА [2] рдореЗрдВ рдорд┐рд▓ рд╕рдХрддреА рд╣реИред

рдФрд░ рдпрд╣рд╛рдБ рд╡рд╛рджрд╛ рдХрд┐рдпрд╛ рдЧрдпрд╛ рдмреЛрдирд╕ рд╣реИред рдЪреВрдВрдХрд┐ рд╣рдо рд╡рд┐рднрд┐рдиреНрди рддреНрд░реБрдЯрд┐ рдХрд╛рд░реНрдпреЛрдВ рдХреЗ рдмреАрдЪ рд╕рдВрдмрдВрдзреЛрдВ рдХреЗ рдмрд╛рд░реЗ рдореЗрдВ рдмрд╛рдд рдХрд░ рд░рд╣реЗ рд╣реИрдВ, рд╣рдо рд╡рд┐рдЪрд╛рд░ рдХрд░реЗрдВрдЧреЗ (рдкрдврд╝рдиреЗ рдХреЗ рд▓рд┐рдП рдЖрд╡рд╢реНрдпрдХ рдирд╣реАрдВ):

рдХреНрд░реЙрд╕-рдПрдиреНрдЯреНрд░реЙрдкреА, рд▓рд┐рдХреЗрд▓рд┐рд╣реБрдб, рд▓реЙрдЧ-рд▓рд┐рдХреЗрд▓рд┐рд╣реБрдб рдФрд░ рдиреЗрдЧреЗрдЯрд┐рд╡ рд▓реЙрдЧ-рд▓рд┐рдХреЗрд▓рд┐рд╣реБрдб рдХреЗ рдмреАрдЪ рд╕рдВрдмрдВрдз
рдорд╛рди рд▓реАрдЬрд┐рдП рд╣рдорд╛рд░реЗ рдкрд╛рд╕ рдбреЗрдЯрд╛ рд╣реИ X = \ {x_1, x_2, x_3, x_4, ... \} , рдкреНрд░рддреНрдпреЗрдХ рдмрд┐рдВрджреБ рдПрдХ рд╡рд┐рд╢рд┐рд╖реНрдЯ рд╡рд░реНрдЧ рдХрд╛ рд╣реИ, рдЙрджрд╛рд╣рд░рдг рдХреЗ рд▓рд┐рдП \ {x_1 \ rightarrow1, x_2 \ rightarrow2, x_3 \ rightarrow n, ... \} ред рдХреБрд▓ рд╡рд╣рд╛рдБ рдПрди рдХрдХреНрд╖рд╛рдПрдВ, рдЬрдмрдХрд┐ рдХрдХреНрд╖рд╛ 1 рд╣реЛрддреА рд╣реИрдВ c1 рд╕рдордп, рдХрдХреНрд╖рд╛ 2 - c2 рд╕рдордп рдФрд░ рдХрдХреНрд╖рд╛ рдПрди - cn рд╕рдордпред рдЗрд╕ рдбреЗрдЯрд╛ рдкрд░ рд╣рдордиреЗ рдХреБрдЫ рдореЙрдбрд▓ рдХреЛ рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдХрд┐рдпрд╛ $рдереАрдЯрд╛ ред рдЗрд╕рдХреЗ рд▓рд┐рдП рд╕рдВрднрд╛рд╡рдирд╛ рд╕рдорд╛рд░реЛрд╣ (рд╕рдВрднрд╛рд╡рдирд╛) рдЗрд╕ рддрд░рд╣ рджрд┐рдЦреЗрдЧрд╛:

P(рдбреЗрдЯрд╛| theta)=P(0,1,...,nред Theta)=P(0|редTheрдереАрдЯрд╛)P(1| theрдереАрдЯрд╛)...P(n| theрдереАрдЯрд╛)

P(1| theta)P(2| theta)...P(nред TheрдереА)= prodc1 haty1 prodc2 haty2... prodcn hatyn= hatyc11 hatyc22... hatycnn


рдЬрд╣рд╛рдБ P(n| theta)= hatyn - рдХрдХреНрд╖рд╛ рдХреЗ рд▓рд┐рдП рд╕рдВрднрд╛рд╡рдирд╛ рдХреА рднрд╡рд┐рд╖реНрдпрд╡рд╛рдгреА рдХреА рдПрди ред

рд╣рдо рд╕рдВрднрд╛рд╡рдирд╛ рдлрд╝рдВрдХреНрд╢рди рдХреЗ рд▓рдШреБрдЧрдгрдХ рдХреЛ рд▓реЗрддреЗ рд╣реИрдВ рдФрд░ рд▓реЙрдЧ-рд▓рд╛рдЗрдХреЗрд▓рд┐рд╣реБрдб рдкреНрд░рд╛рдкреНрдд рдХрд░рддреЗ рд╣реИрдВ:

\ log {P (рдбреЗрдЯрд╛ред \ theta)} = \ рд▓реЙрдЧ {(\ hat {y} _1 ^ {c_1} ... \ hat {y} _n ^ {c_n})} = c_1 \ log {\ hat {{ y_1}} + ... + c_n \ log {\ hat {y_n}} = \ sum_i ^ n {c_i \ log {\ hat {y_i}}}

рд╕рдВрднрд╛рд╡рдирд╛ 0рдЯреЛрдкреАy рдореЗрдВ[0,1] рд╕рдВрднрд╛рд╡рдирд╛ рдХреА рдкрд░рд┐рднрд╛рд╖рд╛ рдХреЗ рдЖрдзрд╛рд░ рдкрд░, 0 рд╕реЗ 1 рддрдХ рдХреА рд╕реАрдорд╛ рдореЗрдВ рд╣реИред рдЗрд╕рд▓рд┐рдП, рд▓рдШреБрдЧрдгрдХ рдХрд╛ рдирдХрд╛рд░рд╛рддреНрдордХ рдорд╛рди рд╣реЛрдЧрд╛ред рдФрд░ рдЕрдЧрд░ рд╣рдо рд▓реЙрдЧ-рд▓рд┐рдХреНрд▓рд┐рд╣реБрдб рдХреЛ 1 рд╕реЗ рдЧреБрдгрд╛ рдХрд░рддреЗ рд╣реИрдВ, рддреЛ рд╣рдореЗрдВ рдлрдВрдХреНрд╢рди рдирд┐рдЧреЗрдЯрд┐рд╡ рд▓реЙрдЧ-рд▓рд┐рдХреНрд▓рд┐рд╣реБрдб (NLL) рдорд┐рд▓рддрд╛ рд╣реИ:

рдПрдирдПрд▓рдПрд▓=тИТ рд▓реЙрдЧрдкреА(рдбреЗрдЯрд╛ред рдереАрдЯрд╛)=тИТ sumnici log hatyi

рдпрджрд┐ рд╣рдо рдПрдирдПрд▓рдПрд▓ рдХреЛ рдЕрдВрдХреЛрдВ рдХреА рд╕рдВрдЦреНрдпрд╛ рд╕реЗ рд╡рд┐рднрд╛рдЬрд┐рдд рдХрд░рддреЗ рд╣реИрдВ X ред N=c1+c2+...+cn рддрдм рд╣рдореЗрдВ рдорд┐рд▓рддрд╛ рд╣реИ:

- \ frac {1} {N} \ log {P (dataред theta)} = - \ _ sum_i ^ n {\ frac {c_i} {N} \ log {\ hat {y_i}}

рдпрд╣ рдзреНрдпрд╛рди рджрд┐рдпрд╛ рдЬрд╛ рд╕рдХрддрд╛ рд╣реИ рдХрд┐ рдХрдХреНрд╖рд╛ рдХреЗ рд▓рд┐рдП рд╡рд╛рд╕реНрддрд╡рд┐рдХ рд╕рдВрднрд╛рд╡рдирд╛ рдПрди рдХреЗ рдмрд░рд╛рдмрд░ рд╣реИ: yn= fraccnN ред рдпрд╣рд╛рдБ рд╕реЗ рд╣рдореЗрдВ рдорд┐рд▓рддрд╛ рд╣реИ:

NLL=тИТ sumniyi log hatyi

рдЕрдм рдпрджрд┐ рдЖрдк рдХреНрд░реЙрд╕ рдПрдиреНрдЯреНрд░реЙрдкреА рдХреА рдкрд░рд┐рднрд╛рд╖рд╛ рдХреЛ рджреЗрдЦрддреЗ рд╣реИрдВ H(p,q)=тИТ sump logq рддрдм рд╣рдореЗрдВ рдорд┐рд▓рддрд╛ рд╣реИ:

рдПрдирдПрд▓рдПрд▓=рдПрдЪ(yi, hatyi)

рдорд╛рдорд▓реЗ рдореЗрдВ рдЬрдм рд╣рдорд╛рд░реЗ рдкрд╛рд╕ рдХреЗрд╡рд▓ рджреЛ рд╡рд░реНрдЧ рд╣реИрдВ n=2 (рдмрд╛рдЗрдирд░реА рд╡рд░реНрдЧреАрдХрд░рдг) рд╣рдореЗрдВ рдмрд╛рдЗрдирд░реА рдХреНрд░реЙрд╕ рдПрдиреНрдЯреНрд░реЙрдкреА рдХрд╛ рд╕реВрддреНрд░ рдорд┐рд▓рддрд╛ рд╣реИ (рдЖрдк рдкреНрд░рд╕рд┐рджреНрдз рдирд╛рдо рд▓реЙрдЧ-рд▓реЙрд╕ рд╕реЗ рднреА рдорд┐рд▓ рд╕рдХрддреЗ рд╣реИрдВ):

H (y, \ hat {y}) = - (y \ log {\ hat {y}} + + (1-y) \ log {(1- \ hat {y}})}

рдЗрд╕ рд╕рдм рд╕реЗ, рдпрд╣ рд╕рдордЭрд╛ рдЬрд╛ рд╕рдХрддрд╛ рд╣реИ рдХрд┐ рдХреБрдЫ рдорд╛рдорд▓реЛрдВ рдореЗрдВ рдХреНрд░реЙрд╕-рдПрдиреНрдЯреНрд░реЙрдкреА рдХреЛ рдХрдо рдХрд░рдирд╛ рдПрдирдПрд▓рдПрд▓ рдХреЛ рдХрдо рдХрд░рдиреЗ рдпрд╛ рд╕рдВрднрд╛рд╡рдирд╛ рдлрд╝рдВрдХреНрд╢рди (рд▓рд┐рдХреЗрд▓рд┐рд╣реБрдб) рдпрд╛ рд▓реЙрдЧ-рд▓рд┐рдХреЗрд▓рд┐рд╣реБрдб рдХрд╛ рдЕрдзрд┐рдХрддрдо рдкрддрд╛ рд▓рдЧрд╛рдиреЗ рдХреЗ рдмрд░рд╛рдмрд░ рд╣реИред

рдПрдХ рдЙрджрд╛рд╣рд░рдг рд╣реИред рдПрдХ рджреНрд╡рд┐рдЖрдзрд╛рд░реА рд╡рд░реНрдЧреАрдХрд░рдг рдкрд░ рд╡рд┐рдЪрд╛рд░ рдХрд░реЗрдВред рд╣рдорд╛рд░реЗ рдкрд╛рд╕ рд╡рд░реНрдЧ рдореВрд▓реНрдп рд╣реИрдВ:

y = np.array([0, 1, 1, 1, 1, 0, 1, 1]).astype(np.float32) 

рд╡рд╛рд╕реНрддрд╡рд┐рдХ рд╕рдВрднрд╛рд╡рдирд╛ рдп рд╡рд░реНрдЧ 0 рдХреЗ рд▓рд┐рдП рдмрд░рд╛рдмрд░ рд╣реИ 2/8=0.25 , рдХрдХреНрд╖рд╛ 1 рдХреЗ рд▓рд┐рдП рдмрд░рд╛рдмрд░ рд╣реИ 6/8=0.75 ред рдорд╛рди рд▓реАрдЬрд┐рдП рдХрд┐ рд╣рдорд╛рд░реЗ рдкрд╛рд╕ рдПрдХ рдмрд╛рдЗрдирд░реА рдХреНрд▓рд╛рд╕рд┐рдлрд╛рдпрд░рд┐рдпрд░ рд╣реИ рдЬреЛ рдХреНрд▓рд╛рд╕ 0 рдХреА рд╕рдВрднрд╛рд╡рдирд╛ рдХреА рднрд╡рд┐рд╖реНрдпрд╡рд╛рдгреА рдХрд░рддрд╛ рд╣реИ  y рдкреНрд░рддреНрдпреЗрдХ рдЙрджрд╛рд╣рд░рдг рдХреЗ рд▓рд┐рдП, рдХреНрд░рдорд╢рдГ, рдХрдХреНрд╖рд╛ 1 рдХреЗ рд▓рд┐рдП, рд╕рдВрднрд╛рд╡рдирд╛ рд╣реИ (1тИТрд╣реИрдЯy) ред рдЖрдЗрдП рд╡рд┐рднрд┐рдиреНрди рдкреВрд░реНрд╡рд╛рдиреБрдорд╛рдиреЛрдВ рдХреЗ рд▓рд┐рдП рд▓реЙрдЧ-рд▓реЙрд╕ рдлрд╝рдВрдХреНрд╢рди рдХреЗ рдореВрд▓реНрдпреЛрдВ рдХреЛ рдкреНрд▓реЙрдЯ рдХрд░реЗрдВ  y :


рдЧреНрд░рд╛рдлрд╝ рдкрд░ рдЖрдк рджреЗрдЦ рд╕рдХрддреЗ рд╣реИрдВ рдХрд┐ рд▓реЙрдЧ-рд▓реЙрд╕ рдлрд╝рдВрдХреНрд╢рди рдХрд╛ рдиреНрдпреВрдирддрдо рдмрд┐рдВрджреБ 0.75 рд╕реЗ рдореЗрд▓ рдЦрд╛рддрд╛ рд╣реИ, рдЕрд░реНрдерд╛рддред рдпрджрд┐ рд╣рдорд╛рд░рд╛ рдореЙрдбрд▓ рдкреВрд░реА рддрд░рд╣ рд╕реЗ рд╕реНрд░реЛрдд рдбреЗрдЯрд╛ рдХреЗ рд╡рд┐рддрд░рдг рдХреЛ "рд╕реАрдЦрд╛" рд╣реИ,  y=y ред

рддрдВрддреНрд░рд┐рдХрд╛ рдиреЗрдЯрд╡рд░реНрдХ рдкреНрд░рддрд┐рдЧрдорди


рдЗрд╕рд▓рд┐рдП рд╣рдо рдПрдХ рдФрд░ рджрд┐рд▓рдЪрд╕реНрдк рдЕрднреНрдпрд╛рд╕ рдореЗрдВ рдЖрдПред рдЖрдЗрдП рджреЗрдЦреЗрдВ рдХрд┐ рдЖрдк рддрдВрддреНрд░рд┐рдХрд╛ рдиреЗрдЯрд╡рд░реНрдХ (рддрдВрддреНрд░рд┐рдХрд╛ рдиреЗрдЯрд╡рд░реНрдХ) рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ рдкреНрд░рддрд┐рдЧрдорди рдХреА рд╕рдорд╕реНрдпрд╛ рдХреЛ рдХреИрд╕реЗ рд╣рд▓ рдХрд░ рд╕рдХрддреЗ рд╣реИрдВред рд╣рдо рдкрд╛рдпрдерди рдкреНрд░реЛрдЧреНрд░рд╛рдорд┐рдВрдЧ рднрд╛рд╖рд╛ рдореЗрдВ рд╕рдм рдХреБрдЫ рд▓рд╛рдЧреВ рдХрд░реЗрдВрдЧреЗ, рдПрдХ рдиреЗрдЯрд╡рд░реНрдХ рдмрдирд╛рдиреЗ рдХреЗ рд▓рд┐рдП рд╣рдо PyTorch рдбреАрдк рд▓рд░реНрдирд┐рдВрдЧ рд▓рд╛рдЗрдмреНрд░реЗрд░реА рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рддреЗ рд╣реИрдВред

рд╕реНрд░реЛрдд рдбреЗрдЯрд╛ рдкреАрдврд╝реА


рдЗрдирдкреБрдЯ рдбреЗрдЯрд╛  mathbfX in mathbbRN рдПрдХ рд╕рдорд╛рди рд╡рд┐рддрд░рдг рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ рдЙрддреНрдкрдиреНрди рдХрд░реЗрдВ, рдЕрдВрддрд░рд╛рд▓ -15 рд╕реЗ 15 рддрдХ рд▓реЗ рдЬрд╛рдПрдВ, U [-15, 15] $ рдореЗрдВ $ \ mathbf {X} ред рдЕрдВрдХ  mathbfY рд╣рдо рд╕рдореАрдХрд░рдг рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░ рдкреНрд░рд╛рдкреНрдд рдХрд░рддреЗ рд╣реИрдВ:

 mathbfY=0.5 mathbfX+8 sin(0.3 mathbfX)+рд╢реЛрд░ qquad qquad(3)

рдЬрд╣рд╛рдБ рд╢реЛрд░ рдЖрдпрд╛рдо рдХрд╛ рдПрдХ рд╢реЛрд░ рд╡реЗрдХреНрдЯрд░ рд╣реИ рдПрди рдорд╛рдкрджрдВрдбреЛрдВ рдХреЗ рд╕рд╛рде рд╕рд╛рдорд╛рдиреНрдп рд╡рд┐рддрд░рдг рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ рдкреНрд░рд╛рдкреНрдд:  mu=0, sigma2=1 ред

рдбреЗрдЯрд╛ рдЬрдирд░реЗрд╢рди
 N = 3000 #   IN_DIM = 1 OUT_DIM = IN_DIM x = np.random.uniform(-15., 15., (IN_DIM, N)).T.astype(np.float32) noise = np.random.normal(size=(N, 1)).astype(np.float32) y = 0.5*x+ 8.*np.sin(0.3*x) + noise #  3 x_train, x_test, y_train, y_test = train_test_split(x, y) #      



рдкреНрд░рд╛рдкреНрдд рдЖрдВрдХрдбрд╝реЛрдВ рдХрд╛ рдЧреНрд░рд╛рдлред

рдиреЗрдЯрд╡рд░реНрдХ рдмрд┐рд▓реНрдбрд┐рдВрдЧ


рдПрдХ рдирд┐рдпрдорд┐рдд рдлрд╝реАрдб рдлреЙрд░рд╡рд░реНрдб рдиреНрдпреВрд░рд▓ рдиреЗрдЯрд╡рд░реНрдХ рдпрд╛ FFNN рдмрдирд╛рдПрдВред

рдПрдлрдПрдлрдПрдирдПрди рдХрд╛ рдирд┐рд░реНрдорд╛рдг
 class Net(nn.Module): def __init__(self, input_dim=IN_DIM, out_dim=OUT_DIM, layer_size=40): super(Net, self).__init__() self.fc = nn.Linear(input_dim, layer_size) self.logit = nn.Linear(layer_size, out_dim) def forward(self, x): x = F.tanh(self.fc(x)) #  4 x = self.logit(x) return x 


рд╣рдорд╛рд░реЗ рдиреЗрдЯрд╡рд░реНрдХ рдореЗрдВ 40 рдиреНрдпреВрд░реЙрдиреНрд╕ рдХреЗ рдЖрдпрд╛рдо рдХреЗ рд╕рд╛рде рдФрд░ рд╕рдХреНрд░рд┐рдпрдг рдлрд╝рдВрдХреНрд╢рди рдХреЗ рд╕рд╛рде рдПрдХ рдЫрд┐рдкреА рд╣реБрдИ рдкрд░рдд рд╣реЛрддреА рд╣реИ - рд╣рд╛рдЗрдкрд░рдмреЛрд▓рд┐рдХ рд╕реНрдкрд░реНрд╢рд░реЗрдЦрд╛:

 tanhx= fracexтИТeтИТxex+eтИТx qquad qquad(4)

рдЖрдЙрдЯрдкреБрдЯ рдкрд░рдд рдПрдХ рд╕рдХреНрд░рд┐рдпрдг рдлрд╝рдВрдХреНрд╢рди рдХреЗ рдмрд┐рдирд╛ рдПрдХ рд╕рд╛рдорд╛рдиреНрдп рд░реИрдЦрд┐рдХ рдкрд░рд┐рд╡рд░реНрддрди рд╣реИред

рд╕реАрдЦрдирд╛ рдФрд░ рдкрд░рд┐рдгрд╛рдо рдкреНрд░рд╛рдкреНрдд рдХрд░рдирд╛


рдПрдХ рдЕрдиреБрдХреВрд▓рдХ рдХреЗ рд░реВрдк рдореЗрдВ рд╣рдо AdamOptimizer рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░реЗрдВрдЧреЗред рдЕрдзреНрдпрдпрди рдХреЗ рдпреБрдЧреЛрдВ рдХреА рд╕рдВрдЦреНрдпрд╛ = 2000, рд╕реАрдЦрдиреЗ рдХреА рджрд░ (рд╕реАрдЦрдиреЗ рдХреА рджрд░ рдпрд╛ lr) = 0.1ред

рдПрдлрдПрдлрдПрдирдПрди рдкреНрд░рд╢рд┐рдХреНрд╖рдг
 def train(net, x_train, y_train, x_test, y_test, epoches=2000, lr=0.1): criterion = nn.MSELoss() optimizer = optim.Adam(net.parameters(), lr=lr) N_EPOCHES = epoches BS = 1500 n_batches = int(np.ceil(x_train.shape[0] / BS)) train_losses = [] test_losses = [] for i in range(N_EPOCHES): for bi in range(n_batches): x_batch, y_batch = fetch_batch(x_train, y_train, bi, BS) x_train_var = Variable(torch.from_numpy(x_batch)) y_train_var = Variable(torch.from_numpy(y_batch)) optimizer.zero_grad() outputs = net(x_train_var) loss = criterion(outputs, y_train_var) loss.backward() optimizer.step() with torch.no_grad(): x_test_var = Variable(torch.from_numpy(x_test)) y_test_var = Variable(torch.from_numpy(y_test)) outputs = net(x_test_var) test_loss = criterion(outputs, y_test_var) test_losses.append(test_loss.item()) train_losses.append(loss.item()) if i%100 == 0: sys.stdout.write('\r Iter: %d, test loss: %.5f, train loss: %.5f' %(i, test_loss.item(), loss.item())) sys.stdout.flush() return train_losses, test_losses net = Net() train_losses, test_losses = train(net, x_train, y_train, x_test, y_test) 


рдЕрдм рд╕реАрдЦрдиреЗ рдХреЗ рдкрд░рд┐рдгрд╛рдореЛрдВ рдкрд░ рдирдЬрд░ рдбрд╛рд▓рддреЗ рд╣реИрдВред


рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреА рдкреБрдирд░рд╛рд╡реГрддреНрддрд┐ рдХреЗ рдЖрдзрд╛рд░ рдкрд░ MSE рдлрд╝рдВрдХреНрд╢рди рдорд╛рдиреЛрдВ рдХрд╛ рдЧреНрд░рд╛рдлрд╝, рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдбреЗрдЯрд╛ рдФрд░ рдкрд░реАрдХреНрд╖рдг рдбреЗрдЯрд╛ рдХреЗ рд▓рд┐рдП рдореВрд▓реНрдпреЛрдВ рдХрд╛ рдЧреНрд░рд╛рдлрд╝ред


рдкрд░реАрдХреНрд╖рдг рдбреЗрдЯрд╛ рдкрд░ рд╡рд╛рд╕реНрддрд╡рд┐рдХ рдФрд░ рдЕрдиреБрдорд╛рдирд┐рдд рдкрд░рд┐рдгрд╛рдоред

рдЙрд▓рдЯрд╛ рдбреЗрдЯрд╛


рд╣рдо рдХрд╛рд░реНрдп рдХреЛ рдЬрдЯрд┐рд▓ рдХрд░рддреЗ рд╣реИрдВ рдФрд░ рдбреЗрдЯрд╛ рдХреЛ рдЙрд▓реНрдЯрд╛ рдХрд░рддреЗ рд╣реИрдВред

рдбреЗрдЯрд╛ рдЙрд▓рдЯрд╛
 x_train_inv = y_train y_train_inv = x_train x_test_inv = y_train y_test_inv = x_train 



рдЙрд▓рдЯрд╛ рдбреЗрдЯрд╛ рдЧреНрд░рд╛рдлред

рднрд╡рд┐рд╖реНрдпрд╡рд╛рдгреА рдХреЗ рд▓рд┐рдП  mathbf hatY рдЪрд▓реЛ рдкрд┐рдЫрд▓реЗ рдЕрдиреБрднрд╛рдЧ рд╕реЗ рдкреНрд░рддреНрдпрдХреНрд╖ рд╡рд┐рддрд░рдг рдиреЗрдЯрд╡рд░реНрдХ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░реЗрдВ рдФрд░ рджреЗрдЦреЗрдВ рдХрд┐ рдпрд╣ рдХреИрд╕реЗ рд╕рдВрднрд╛рд▓рддрд╛ рд╣реИред

 inv_train_losses, inv_test_losses = train(net, x_train_inv, y_train_inv, x_test_inv, y_test_inv) 


рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреА рдкреБрдирд░рд╛рд╡реГрддреНрддрд┐ рдХреЗ рдЖрдзрд╛рд░ рдкрд░ MSE рдлрд╝рдВрдХреНрд╢рди рдорд╛рдиреЛрдВ рдХрд╛ рдЧреНрд░рд╛рдлрд╝, рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдбреЗрдЯрд╛ рдФрд░ рдкрд░реАрдХреНрд╖рдг рдбреЗрдЯрд╛ рдХреЗ рд▓рд┐рдП рдореВрд▓реНрдпреЛрдВ рдХрд╛ рдЧреНрд░рд╛рдлрд╝ред


рдкрд░реАрдХреНрд╖рдг рдбреЗрдЯрд╛ рдкрд░ рд╡рд╛рд╕реНрддрд╡рд┐рдХ рдФрд░ рдЕрдиреБрдорд╛рдирд┐рдд рдкрд░рд┐рдгрд╛рдоред

рдЬреИрд╕рд╛ рдХрд┐ рдЖрдк рдКрдкрд░ рджрд┐рдП рдЧрдП рдЧреНрд░рд╛рдлрд╝ рд╕реЗ рджреЗрдЦ рд╕рдХрддреЗ рд╣реИрдВ, рд╣рдорд╛рд░реЗ рдиреЗрдЯрд╡рд░реНрдХ рдиреЗ рдЗрд╕ рддрд░рд╣ рдХреЗ рдбреЗрдЯрд╛ рдХрд╛ рдмрд┐рд▓реНрдХреБрд▓ рднреА рд╕рд╛рдордирд╛ рдирд╣реАрдВ рдХрд┐рдпрд╛ рд╣реИ , рдпрд╣ рдХреЗрд╡рд▓ рдЙрдирдХреА рднрд╡рд┐рд╖реНрдпрд╡рд╛рдгреА рдирд╣реАрдВ рдХрд░ рд╕рдХрддрд╛ рд╣реИред рдФрд░ рдпрд╣ рд╕рдм рдЗрд╕рд▓рд┐рдП рд╣реБрдЖ рдХреНрдпреЛрдВрдХрд┐ рдПрдХ рдмрд┐рдВрджреБ рдХреЗ рд▓рд┐рдП рдРрд╕реА рдЙрд▓рдЯреА рд╕рдорд╕реНрдпрд╛ рдереА x рдХрдИ рдмрд┐рдВрджреБрдУрдВ рдХреЗ рдЕрдиреБрд░реВрдк рд╣реЛ рд╕рдХрддрд╛ рд╣реИ рдп ред рдЖрдк рдкреВрдЫрддреЗ рд╣реИрдВ, рд╢реЛрд░ рдХреЗ рдмрд╛рд░реЗ рдореЗрдВ рдХреНрдпрд╛? рдЙрдиреНрд╣реЛрдВрдиреЗ рдПрдХ рд╕реНрдерд┐рддрд┐ рднреА рдмрдирд╛рдИ рдЬрд┐рд╕рдореЗрдВ рдПрдХ рдХреЗ рд▓рд┐рдП x рдХреБрдЫ рдореВрд▓реНрдп рдорд┐рд▓ рд╕рдХрддрд╛ рд╣реИ рдп ред рд╣рд╛рдБ, рдпрд╣ рд╕рд╣реА рд╣реИред рд▓реЗрдХрд┐рди рдкреВрд░реЗ рдмрд┐рдВрджреБ рдпрд╣ рд╣реИ рдХрд┐ рд╢реЛрд░ рдХреЗ рдмрд╛рд╡рдЬреВрдж, рдпрд╣ рд╕рднреА рдПрдХ рдирд┐рд╢реНрдЪрд┐рдд рд╡рд┐рддрд░рдг рдерд╛ред рдФрд░ рдЪреВрдВрдХрд┐ рд╣рдорд╛рд░реЗ рдореЙрдбрд▓ рдиреЗ рдЕрдирд┐рд╡рд╛рд░реНрдп рд░реВрдк рд╕реЗ рднрд╡рд┐рд╖реНрдпрд╡рд╛рдгреА рдХреА рдереА p(y|x) , рдФрд░ рдПрдордПрд╕рдИ рдХреЗ рдорд╛рдорд▓реЗ рдореЗрдВ рдпрд╣ рд╕рд╛рдорд╛рдиреНрдп рд╡рд┐рддрд░рдг рдХреЗ рд▓рд┐рдП рдФрд╕рдд рдореВрд▓реНрдп рдерд╛ (рдпрд╣ рд▓реЗрдЦ рдХреЗ рдкрд╣рд▓реЗ рднрд╛рдЧ рдореЗрдВ рдХреНрдпреЛрдВ рд╡рд░реНрдгрд┐рдд рд╣реИ), рддреЛ рдпрд╣ "рдкреНрд░рддреНрдпрдХреНрд╖" рдХрд╛рд░реНрдп рдХреЗ рд╕рд╛рде рдЕрдЪреНрдЫреА рддрд░рд╣ рд╕реЗ рдореБрдХрд╛рдмрд▓рд╛ рдХрд┐рдпрд╛ред рдЕрдиреНрдпрдерд╛, рд╣рдо рдПрдХ рдХреЗ рд▓рд┐рдП рдХрдИ рдЕрд▓рдЧ-рдЕрд▓рдЧ рд╡рд┐рддрд░рдг рдкреНрд░рд╛рдкреНрдд рдХрд░рддреЗ рд╣реИрдВ x рдФрд░ рддрджрдиреБрд╕рд╛рд░ рд╣рдореЗрдВ рдХреЗрд╡рд▓ рдПрдХ рд╕рд╛рдорд╛рдиреНрдп рд╡рд┐рддрд░рдг рдХреЗ рд╕рд╛рде рдПрдХ рдЕрдЪреНрдЫрд╛ рдкрд░рд┐рдгрд╛рдо рдирд╣реАрдВ рдорд┐рд▓ рд╕рдХрддрд╛ рд╣реИред

рдорд┐рд╢реНрд░рдг рдШрдирддреНрд╡ рдиреЗрдЯрд╡рд░реНрдХ


рдордЬрд╝рд╛ рд╢реБрд░реВ рд╣реЛрддрд╛ рд╣реИ! рдорд┐рдХреНрд╕рдЪрд░ рдбреЗрдВрд╕рд┐рдЯреА рдиреЗрдЯрд╡рд░реНрдХ (рдЗрд╕рдХреЗ рдмрд╛рдж рдПрдордбреАрдПрди рдпрд╛ рдПрдордбреА рдиреЗрдЯрд╡рд░реНрдХ) рдХреНрдпрд╛ рд╣реИ? рд╕рд╛рдорд╛рдиреНрдп рддреМрд░ рдкрд░, рдпрд╣ рдПрдХ рдирд┐рд╢реНрдЪрд┐рдд рдореЙрдбрд▓ рд╣реИ рдЬреЛ рдПрдХ рд╕рд╛рде рдХрдИ рд╡рд┐рддрд░рдгреЛрдВ рдХреЛ рдореЙрдбрд▓ рдХрд░рдиреЗ рдореЗрдВ рд╕рдХреНрд╖рдо рд╣реИ:

p (\ mathbf {y} | \ mathbf {x}; theta) = \ sum_k ^ K \ pi_k (\ mathbf {x}) \ mathcal {N} (\ mathbf / y}; mu_k (\ mathbf {) x}), \ sigma ^ 2 (\ mathbf {x})) \ qquad \ qquad (5)

рдХреНрдпрд╛ рдЕрдЬреАрдм рд╕реВрддреНрд░ рд╣реИ, рдЖрдк рдХрд╣рддреЗ рд╣реИрдВред рдЪрд▓рд┐рдП рдЗрд╕рдХрд╛ рдкрддрд╛ рд▓рдЧрд╛рддреЗ рд╣реИрдВред рд╣рдорд╛рд░рд╛ рдПрдордбреА рдиреЗрдЯрд╡рд░реНрдХ рдорд╛рдбрд▓ рдмрдирд╛рдирд╛ рд╕реАрдЦ рд░рд╣рд╛ рд╣реИ  рдореБ рдФрд░ рд╡рд┐рдЪрд░рдг  рд╕рд┐рдЧреНрдорд╛2 рдХрдИ рд╡рд┐рддрд░рдг рдХреЗ рд▓рд┐рдПред рд╕реВрддреНрд░ рдореЗрдВ (5)  pik( mathbfx) - рдкреНрд░рддреНрдпреЗрдХ рдмрд┐рдВрджреБ рдХреЗ рд▓рд┐рдП рдПрдХ рдЕрд▓рдЧ рд╡рд┐рддрд░рдг рдХреЗ рддрдерд╛рдХрдерд┐рдд рдорд╣рддреНрд╡ рдХрд╛рд░рдХ xi in mathbfx рдПрдХ рдирд┐рд╢реНрдЪрд┐рдд рдорд┐рд╢реНрд░рдг рдХрд╛рд░рдХ, рдпрд╛ рд╡рд┐рддрд░рдг рдореЗрдВ рд╕реЗ рдкреНрд░рддреНрдпреЗрдХ рдПрдХ рдирд┐рд╢реНрдЪрд┐рдд рдмрд┐рдВрджреБ рдкрд░ рдХрд┐рддрдирд╛ рдпреЛрдЧрджрд╛рди рджреЗрддрд╛ рд╣реИред рдХреБрд▓ рд╡рд╣рд╛рдБ K рд╡рд┐рддрд░рдгред

рдХреЗ рдмрд╛рд░реЗ рдореЗрдВ рдХреБрдЫ рдФрд░ рд╢рдмреНрдж  pik( mathbfx) - рд╡рд╛рд╕реНрддрд╡ рдореЗрдВ, рдпрд╣ рдПрдХ рд╡рд┐рддрд░рдг рднреА рд╣реИ рдФрд░ рдПрдХ рдмрд┐рдВрджреБ рдХреЗ рд▓рд┐рдП рд╕рдВрднрд╛рд╡рдирд╛ рдХрд╛ рдкреНрд░рддрд┐рдирд┐рдзрд┐рддреНрд╡ рдХрд░рддрд╛ рд╣реИ xi in mathbfx рдПрдХ рд╢рд░реНрдд рд╣реЛрдЧреА k ред

рдлреВ, рдлрд┐рд░, рдпрд╣ рдЧрдгрд┐рдд, рдЪрд▓реЛ рдкрд╣рд▓реЗ рд╕реЗ рд╣реА рдХреБрдЫ рд▓рд┐рдЦреЗрдВред рдФрд░ рдЗрд╕рд▓рд┐рдП, рдЪрд▓реЛ рдиреЗрдЯрд╡рд░реНрдХ рдХреЛ рд▓рд╛рдЧреВ рдХрд░рдирд╛ рд╢реБрд░реВ рдХрд░рддреЗ рд╣реИрдВред рд╣рдорд╛рд░реЗ рдиреЗрдЯрд╡рд░реНрдХ рдХреЗ рд▓рд┐рдП рд╣рдо рд▓реЗрддреЗ рд╣реИрдВ K=30 ред

 self.fc = nn.Linear(input_dim, layer_size) self.fc2 = nn.Linear(layer_size, 50) self.pi = nn.Linear(layer_size, coefs) self.mu = nn.Linear(layer_size, out_dim*coefs) # mean self.sigma_sq = nn.Linear(layer_size, coefs) # variance 

рд╣рдорд╛рд░реЗ рдиреЗрдЯрд╡рд░реНрдХ рдХреЗ рд▓рд┐рдП рдЖрдЙрдЯрдкреБрдЯ рд▓реЗрдпрд░ рдХреЛ рдкрд░рд┐рднрд╛рд╖рд┐рдд рдХрд░реЗрдВ:

 x = F.relu(self.fc(x)) x = F.relu(self.fc2(x)) pi = F.softmax(self.pi(x), dim=1) sigma_sq = torch.exp(self.sigma_sq(x)) mu = self.mu(x) 

рд╣рдо рддреНрд░реБрдЯрд┐ рдлрд╝рдВрдХреНрд╢рди рдпрд╛ рд╣рд╛рдирд┐ рдлрд╝рдВрдХреНрд╢рди, рд╕реВрддреНрд░ (5) рд▓рд┐рдЦрддреЗ рд╣реИрдВ:

 def gaussian_pdf(x, mu, sigma_sq): return (1/torch.sqrt(2*np.pi*sigma_sq)) * torch.exp((-1/(2*sigma_sq)) * torch.norm((x-mu), 2, 1)**2) losses = Variable(torch.zeros(y.shape[0])) # p(y|x) for i in range(COEFS): likelihood = gaussian_pdf(y, mu[:, i*OUT_DIM:(i+1)*OUT_DIM], sigma_sq[:, i]) prior = pi[:, i] losses += prior * likelihood loss = torch.mean(-torch.log(losses)) 

рдПрдордбреАрдПрди рдмрд┐рд▓реНрдб рдХреЛрдб рдХреЛ рдкреВрд░рд╛ рдХрд░реЗрдВ
 COEFS = 30 class MDN(nn.Module): def __init__(self, input_dim=IN_DIM, out_dim=OUT_DIM, layer_size=50, coefs=COEFS): super(MDN, self).__init__() self.fc = nn.Linear(input_dim, layer_size) self.fc2 = nn.Linear(layer_size, 50) self.pi = nn.Linear(layer_size, coefs) self.mu = nn.Linear(layer_size, out_dim*coefs) # mean self.sigma_sq = nn.Linear(layer_size, coefs) # variance self.out_dim = out_dim self.coefs = coefs def forward(self, x): x = F.relu(self.fc(x)) x = F.relu(self.fc2(x)) pi = F.softmax(self.pi(x), dim=1) sigma_sq = torch.exp(self.sigma_sq(x)) mu = self.mu(x) return pi, mu, sigma_sq #       def gaussian_pdf(x, mu, sigma_sq): return (1/torch.sqrt(2*np.pi*sigma_sq)) * torch.exp((-1/(2*sigma_sq)) * torch.norm((x-mu), 2, 1)**2) #   def loss_fn(y, pi, mu, sigma_sq): losses = Variable(torch.zeros(y.shape[0])) # p(y|x) for i in range(COEFS): likelihood = gaussian_pdf(y, mu[:, i*OUT_DIM:(i+1)*OUT_DIM], sigma_sq[:, i]) prior = pi[:, i] losses += prior * likelihood loss = torch.mean(-torch.log(losses)) return loss 


рд╣рдорд╛рд░рд╛ рдПрдордбреА рдиреЗрдЯрд╡рд░реНрдХ рдЬрд╛рдиреЗ рдХреЗ рд▓рд┐рдП рддреИрдпрд╛рд░ рд╣реИред рд▓рдЧрднрдЧ рддреИрдпрд╛рд░ред рдпрд╣ рдЙрд╕реЗ рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдХрд░рдиреЗ рдФрд░ рдкрд░рд┐рдгрд╛рдореЛрдВ рдХреЛ рджреЗрдЦрдиреЗ рдХреЗ рд▓рд┐рдП рдмрдиреА рд╣реБрдИ рд╣реИред

рдПрдордбреАрдПрди рдкреНрд░рд╢рд┐рдХреНрд╖рдг
 def train_mdn(net, x_train, y_train, x_test, y_test, epoches=1000): optimizer = optim.Adam(net.parameters(), lr=0.01) N_EPOCHES = epoches BS = 1500 n_batches = int(np.ceil(x_train.shape[0] / BS)) train_losses = [] test_losses = [] for i in range(N_EPOCHES): for bi in range(n_batches): x_batch, y_batch = fetch_batch(x_train, y_train, bi, BS) x_train_var = Variable(torch.from_numpy(x_batch)) y_train_var = Variable(torch.from_numpy(y_batch)) optimizer.zero_grad() pi, mu, sigma_sq = net(x_train_var) loss = loss_fn(y_train_var, pi, mu, sigma_sq) loss.backward() optimizer.step() with torch.no_grad(): if i%10 == 0: x_test_var = Variable(torch.from_numpy(x_test)) y_test_var = Variable(torch.from_numpy(y_test)) pi, mu, sigma_sq = net(x_test_var) test_loss = loss_fn(y_test_var, pi, mu, sigma_sq) train_losses.append(loss.item()) test_losses.append(test_loss.item()) sys.stdout.write('\r Iter: %d, test loss: %.5f, train loss: %.5f' %(i, test_loss.item(), loss.item())) sys.stdout.flush() return train_losses, test_losses mdn_net = MDN() mdn_train_losses, mdn_test_losses = train_mdn(mdn_net, x_train_inv, y_train_inv, x_test_inv, y_test_inv) 



рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреА рдкреБрдирд░рд╛рд╡реГрддреНрддрд┐ рдХреЗ рдЖрдзрд╛рд░ рдкрд░ рд╣рд╛рдирд┐ рдлрд╝рдВрдХреНрд╢рди рдореВрд▓реНрдпреЛрдВ рдХрд╛ рдЧреНрд░рд╛рдл, рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдбреЗрдЯрд╛ рдФрд░ рдкрд░реАрдХреНрд╖рдг рдбреЗрдЯрд╛ рдХреЗ рд▓рд┐рдП рдореВрд▓реНрдпреЛрдВ рдХрд╛ рдЧреНрд░рд╛рдлред

рдЪреВрдВрдХрд┐ рд╣рдорд╛рд░реЗ рдиреЗрдЯрд╡рд░реНрдХ рдиреЗ рдХрдИ рд╡рд┐рддрд░рдгреЛрдВ рдХреЗ рд▓рд┐рдП рдорд╛рдзреНрдп рдорд╛рди рд╕реАрдЦреЗ рд╣реИрдВ, рддреЛ рдЖрдЗрдП рдЗрд╕реЗ рджреЗрдЦреЗрдВ:

 pi, mu, sigma_sq = mdn_net(Variable(torch.from_numpy(x_test_inv))) 


рдкреНрд░рддреНрдпреЗрдХ рдмрд┐рдВрджреБ (рдмрд╛рдПрдВ) рдХреЗ рд▓рд┐рдП рджреЛ рд╕рдмрд╕реЗ рдЕрдзрд┐рдХ рд╕рдВрднрд╛рд╡рд┐рдд рдорд╛рдзреНрдп рдорд╛рдиреЛрдВ рдХреЗ рд▓рд┐рдП рдЧреНрд░рд╛рдлрд╝ред рдкреНрд░рддреНрдпреЗрдХ рдмрд┐рдВрджреБ (рджрд╛рдПрдВ) рдХреЗ рд▓рд┐рдП 4 рд╕рдмрд╕реЗ рдЕрдзрд┐рдХ рд╕рдВрднрд╛рд╡рд┐рдд рдорд╛рдзреНрдп рдорд╛рдиреЛрдВ рдХреЗ рд▓рд┐рдП рдЧреНрд░рд╛рдлрд╝ред


рдкреНрд░рддреНрдпреЗрдХ рдмрд┐рдВрджреБ рдХреЗ рд▓рд┐рдП рд╕рднреА рдорд╛рдзреНрдп рдорд╛рдиреЛрдВ рдХреЗ рд▓рд┐рдП рдЧреНрд░рд╛рдлрд╝ред

рдбреЗрдЯрд╛ рдХреА рднрд╡рд┐рд╖реНрдпрд╡рд╛рдгреА рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП, рд╣рдо рдмреЗрддрд░рддреАрдм рдврдВрдЧ рд╕реЗ рдХрдИ рдорд╛рдиреЛрдВ рдХрд╛ рдЪрдпрди рдХрд░реЗрдВрдЧреЗ  рдореБ рдФрд░  рд╕рд┐рдЧреНрдорд╛2 рдореВрд▓реНрдп рдХреЗ рдЖрдзрд╛рд░ рдкрд░  pik( mathbfx) ред рдФрд░ рдлрд┐рд░ рд▓рдХреНрд╖реНрдп рдбреЗрдЯрд╛ рдЙрддреНрдкрдиреНрди рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП рдЙрдирдХреЗ рдЖрдзрд╛рд░ рдкрд░  y рд╕рд╛рдорд╛рдиреНрдп рд╡рд┐рддрд░рдг рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдирд╛ред

рдкрд░рд┐рдгрд╛рдо рдХреА рднрд╡рд┐рд╖реНрдпрд╡рд╛рдгреА
 def rand_n_sample_cumulative(pi, mu, sigmasq, samples=10): n = pi.shape[0] out = Variable(torch.zeros(n, samples, OUT_DIM)) for i in range(n): for j in range(samples): u = np.random.uniform() prob_sum = 0 for k in range(COEFS): prob_sum += pi.data[i, k] if u < prob_sum: for od in range(OUT_DIM): sample = np.random.normal(mu.data[i, k*OUT_DIM+od], np.sqrt(sigmasq.data[i, k])) out[i, j, od] = sample break return out pi, mu, sigma_sq = mdn_net(Variable(torch.from_numpy(x_test_inv))) preds = rand_n_sample_cumulative(pi, mu, sigma_sq, samples=10) 


10 рдмреЗрддрд░рддреАрдм рдврдВрдЧ рд╕реЗ рдЪрдпрдирд┐рдд рдореВрд▓реНрдпреЛрдВ рдХреЗ рд▓рд┐рдП рдЕрдиреБрдорд╛рдирд┐рдд рдбреЗрдЯрд╛  рдореБ рдФрд░  рд╕рд┐рдЧреНрдорд╛2 (рдмрд╛рдПрдВ) рдФрд░ рджреЛ (рджрд╛рдПрдВ) рдХреЗ рд▓рд┐рдПред

рдпрд╣ рдЙрди рдЖрдВрдХрдбрд╝реЛрдВ рд╕реЗ рджреЗрдЦрд╛ рдЬрд╛ рд╕рдХрддрд╛ рд╣реИ рдХрд┐ рдПрдордбреАрдПрди рдиреЗ "рдЙрд▓рдЯрд╛" рдХрд╛рд░реНрдп рдХреЗ рд╕рд╛рде рдПрдХ рдЙрддреНрдХреГрд╖реНрдЯ рдХрд╛рд░реНрдп рдХрд┐рдпрд╛ред

рдЕрдзрд┐рдХ рдЬрдЯрд┐рд▓ рдбреЗрдЯрд╛ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдирд╛


рдЖрдЗрдП рджреЗрдЦреЗрдВ рдХрд┐ рд╣рдорд╛рд░рд╛ рдПрдордбреА рдиреЗрдЯрд╡рд░реНрдХ рд╕рд░реНрдкрд┐рд▓ рдбреЗрдЯрд╛ рдЬреИрд╕реЗ рдЕрдзрд┐рдХ рдЬрдЯрд┐рд▓ рдбреЗрдЯрд╛ рдХреЛ рдХреИрд╕реЗ рд╕рдВрднрд╛рд▓рддрд╛ рд╣реИред рдХрд╛рд░реНрддреАрдп рдирд┐рд░реНрджреЗрд╢рд╛рдВрдХ рдореЗрдВ рд╣рд╛рдЗрдкрд░рдмреЛрд▓рд┐рдХ рд╕рд░реНрдкрд┐рд▓ рдХрд╛ рд╕рдореАрдХрд░рдг:

x= rho cos phi qquad qquad qquad qquad qquad qquad(6)y= rho sin phi

рд╕рд░реНрдкрд┐рд▓ рдбреЗрдЯрд╛ рдЬрдирд░реЗрд╢рди
 N = 2000 x_train_compl = [] y_train_compl = [] x_test_compl = [] y_test_compl = [] noise_train = np.random.uniform(-1, 1, (N, IN_DIM)).astype(np.float32) noise_test = np.random.uniform(-1, 1, (N, IN_DIM)).astype(np.float32) for i, theta in enumerate(np.linspace(0, 5*np.pi, N).astype(np.float32)): #  6 r = ((theta)) x_train_compl.append(r*np.cos(theta) + noise_train[i]) y_train_compl.append(r*np.sin(theta)) x_test_compl.append(r*np.cos(theta) + noise_test[i]) y_test_compl.append(r*np.sin(theta)) x_train_compl = np.array(x_train_compl).reshape((-1, 1)) y_train_compl = np.array(y_train_compl).reshape((-1, 1)) x_test_compl = np.array(x_test_compl).reshape((-1, 1)) y_test_compl = np.array(y_test_compl).reshape((-1, 1)) 



рд╕рд░реНрдкрд┐рд▓ рдбреЗрдЯрд╛ рдХрд╛ рдЧреНрд░рд╛рдлред

рдордЬрд╝реЗ рдХреЗ рд▓рд┐рдП, рдЖрдЗрдП рджреЗрдЦреЗрдВ рдХрд┐ рдПрдХ рдирд┐рдпрдорд┐рдд рдлрд╝реАрдб-рдлреЙрд░рд╡рд░реНрдб рдиреЗрдЯрд╡рд░реНрдХ рдЗрд╕ рддрд░рд╣ рдХреЗ рдХрд╛рд░реНрдп рд╕реЗ рдХреИрд╕реЗ рдирд┐рдкрдЯреЗрдЧрд╛ред


рдЬреИрд╕реА рдХрд┐ рдЙрдореНрдореАрдж рдереА, рдлреАрдб-рдлреЙрд░рд╡рд░реНрдб рдиреЗрдЯрд╡рд░реНрдХ рдРрд╕реЗ рдбреЗрдЯрд╛ рдХреЗ рд▓рд┐рдП рдкреНрд░рддрд┐рдЧрдорди рд╕рдорд╕реНрдпрд╛ рдХреЛ рд╣рд▓ рдХрд░рдиреЗ рдореЗрдВ рд╕рдХреНрд╖рдо рдирд╣реАрдВ рд╣реИред

рд╣рдо рд╕рд░реНрдкрд┐рд▓ рдбреЗрдЯрд╛ рдкрд░ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЗ рд▓рд┐рдП рдкрд╣рд▓реЗ рд╕реЗ рд╡рд░реНрдгрд┐рдд рдФрд░ рдирд┐рд░реНрдорд┐рдд рдПрдордбреА рдиреЗрдЯрд╡рд░реНрдХ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рддреЗ рд╣реИрдВред


рдорд┐рдХреНрд╕рдЪрд░ рдбреЗрдВрд╕рд┐рдЯреА рдиреЗрдЯрд╡рд░реНрдХ рдиреЗ рдЗрд╕ рд╕реНрдерд┐рддрд┐ рдореЗрдВ рдмрд╣реБрдд рдЕрдЪреНрдЫрд╛ рдХрд╛рдо рдХрд┐рдпрд╛ред

рдирд┐рд╖реНрдХрд░реНрд╖


рдЗрд╕ рд▓реЗрдЦ рдХреА рд╢реБрд░реБрдЖрдд рдореЗрдВ, рд╣рдордиреЗ рд░реИрдЦрд┐рдХ рдкреНрд░рддрд┐рдЧрдорди рдХреА рдореВрд▓ рдмрд╛рддреЗрдВ рдпрд╛рдж рдХреАрдВред рд╣рдордиреЗ рджреЗрдЦрд╛ рдХрд┐ рд╕рд╛рдорд╛рдиреНрдп рд╡рд┐рддрд░рдг рдФрд░ рдПрдордПрд╕рдИ рдХреЗ рд▓рд┐рдП рдФрд╕рдд рдЦреЛрдЬрдиреЗ рдХреЗ рдмреАрдЪред рдЦрд╛рд░рд┐рдЬ рдХрд░ рджрд┐рдпрд╛ рдХрд┐ рдПрдирдПрд▓рдПрд▓ рдФрд░ рдХреНрд░реЙрд╕ рдПрдиреНрдЯреНрд░реЙрдкреА рдХреИрд╕реЗ рдЬреБрдбрд╝реЗред рдФрд░ рд╕рдмрд╕реЗ рдорд╣рддреНрд╡рдкреВрд░реНрдг рдмрд╛рдд, рд╣рдордиреЗ рдПрдордбреАрдПрди рдореЙрдбрд▓ рдХрд╛ рдкрддрд╛ рд▓рдЧрд╛рдпрд╛, рдЬреЛ рдорд┐рд╢реНрд░рд┐рдд рд╡рд┐рддрд░рдг рд╕реЗ рдкреНрд░рд╛рдкреНрдд рдЖрдВрдХрдбрд╝реЛрдВ рд╕реЗ рд╕реАрдЦрдиреЗ рдореЗрдВ рд╕рдХреНрд╖рдо рд╣реИред рдореБрдЭреЗ рдЙрдореНрдореАрдж рд╣реИ рдХрд┐ рд▓реЗрдЦ рд╕рдордЭ рдореЗрдВ рдЖрддрд╛ рд╣реИ рдФрд░ рджрд┐рд▓рдЪрд╕реНрдк рд╣реИ, рдЗрд╕ рддрдереНрдп рдХреЗ рдмрд╛рд╡рдЬреВрдж рдХрд┐ рдЧрдгрд┐рдд рдХрд╛ рдПрдХ рд╕рд╛ рдерд╛ред

рдкреВрд░рд╛ рдХреЛрдб GitHub рдкрд░ рджреЗрдЦрд╛ рдЬрд╛ рд╕рдХрддрд╛ рд╣реИ ред


рд╕рд╛рд╣рд┐рддреНрдп


  1. рдорд┐рдХреНрд╕рдЪрд░ рдбреЗрдВрд╕рд┐рдЯреА рдиреЗрдЯрд╡рд░реНрдХреНрд╕ (рдХреНрд░рд┐рд╕реНрдЯреЛрдлрд░ рдПрдоред рдмрд┐рд╢рдк, рдиреНрдпреВрд░рд▓ рдХрдВрдкреНрдпреВрдЯрд┐рдВрдЧ рд░рд┐рд╕рд░реНрдЪ рдЧреНрд░реБрдк, рдХрдВрдкреНрдпреВрдЯрд░ рд╕рд╛рдЗрдВрд╕ рдПрдВрдб рдПрдкреНрд▓рд╛рдЗрдб рдореИрдердореЗрдЯрд┐рдХреНрд╕, рдПрд╕реНрдЯрди рдпреВрдирд┐рд╡рд░реНрд╕рд┐рдЯреА, рдмрд░реНрдорд┐рдВрдШрдо) рд╡рд┐рднрд╛рдЧ - рд▓реЗрдЦ рдПрдордбреА рдиреЗрдЯрд╡рд░реНрдХ рдХреЗ рд╕рд┐рджреНрдзрд╛рдВрдд рдХрд╛ рдкреВрд░реА рддрд░рд╣ рд╕реЗ рд╡рд░реНрдгрди рдХрд░рддрд╛ рд╣реИред
  2. рдХрдо рд╕реЗ рдХрдо рд╡рд░реНрдЧ рдФрд░ рдЕрдзрд┐рдХрддрдо рд╕рдВрднрд╛рд╡рдирд╛ (MROsborne)

Source: https://habr.com/ru/post/hi433804/


All Articles