рдпрджрд┐ рдЖрдк рдорд╢реАрди рд╕реАрдЦрдиреЗ рдореЗрдВ рд░реБрдЪрд┐ рд░рдЦрддреЗ рд╣реИрдВ, рддреЛ рдЖрдкрдиреЗ рд╢рд╛рдпрдж BERT рдФрд░ рдЯреНрд░рд╛рдВрд╕рдлрд╛рд░реНрдорд░ рдХреЗ рдмрд╛рд░реЗ рдореЗрдВ рд╕реБрдирд╛ рд╣реЛрдЧрд╛ред
BERT Google рдХрд╛ рдПрдХ рднрд╛рд╖рд╛ рдореЙрдбрд▓ рд╣реИ, рдЬреЛ рдХрдИ рдХрд╛рд░реНрдпреЛрдВ рдкрд░ рдПрдХ рд╡реНрдпрд╛рдкрдХ рдорд╛рд░реНрдЬрд┐рди рджреНрд╡рд╛рд░рд╛ рдЕрддреНрдпрд╛рдзреБрдирд┐рдХ рдкрд░рд┐рдгрд╛рдо рджрд┐рдЦрд╛рддрд╛ рд╣реИред рдмреАрдИрдЖрд░рдЯреА, рдФрд░ рдЖрдорддреМрд░ рдкрд░ рдЯреНрд░рд╛рдВрд╕рдлрд╛рд░реНрдорд░, рдкреНрд░рд╛рдХреГрддрд┐рдХ рднрд╛рд╖рд╛ рдкреНрд░рд╕рдВрд╕реНрдХрд░рдг рдПрд▓реНрдЧреЛрд░рд┐рджрдо (рдПрдирдПрд▓рдкреА) рдХреЗ рд╡рд┐рдХрд╛рд╕ рдореЗрдВ рдПрдХ рдкреВрд░реА рддрд░рд╣ рд╕реЗ рдирдпрд╛ рдХрджрдо рдмрди рдЧрдП рд╣реИрдВред рдЙрдирдХреЗ рдмрд╛рд░реЗ рдореЗрдВ рд▓реЗрдЦ рдФрд░ рд╡рд┐рднрд┐рдиреНрди рдмреЗрдВрдЪрдорд╛рд░реНрдХ рдХреЗ рд▓рд┐рдП "рд╕реНрдЯреИрдВрдбрд┐рдВрдЧ" рдХреЛ Papers With Code рд╡реЗрдмрд╕рд╛рдЗрдЯ рдкрд░ рдкрд╛рдпрд╛ рдЬрд╛ рд╕рдХрддрд╛ рд╣реИред
рдмреАрдИрдЖрд░рдЯреА рдХреЗ рд╕рд╛рде рдПрдХ рд╕рдорд╕реНрдпрд╛ рд╣реИ: рдпрд╣ рдФрджреНрдпреЛрдЧрд┐рдХ рдкреНрд░рдгрд╛рд▓рд┐рдпреЛрдВ рдореЗрдВ рдЙрдкрдпреЛрдЧ рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП рд╕рдорд╕реНрдпрд╛рдЧреНрд░рд╕реНрдд рд╣реИред рдмреАрдИрдЖрд░рдЯреА-рдмреЗрд╕ рдореЗрдВ 110 рдПрдо рдкреИрд░рд╛рдореАрдЯрд░, рдмреАрдИрдЖрд░рдЯреА-рдмрдбрд╝реЗ - 340 рдПрдо рд╢рд╛рдорд┐рд▓ рд╣реИрдВред рдЗрддрдиреА рдмрдбрд╝реА рд╕рдВрдЦреНрдпрд╛ рдореЗрдВ рдорд╛рдкрджрдВрдбреЛрдВ рдХреЗ рдХрд╛рд░рдг, рдЗрд╕ рдореЙрдбрд▓ рдХреЛ рд╕реАрдорд┐рдд рд╕рдВрд╕рд╛рдзрдиреЛрдВ рд╡рд╛рд▓реЗ рдЙрдкрдХрд░рдгреЛрдВ рдЬреИрд╕реЗ рдореЛрдмрд╛рдЗрд▓ рдлреЛрди рдкрд░ рдбрд╛рдЙрдирд▓реЛрдб рдХрд░рдирд╛ рдореБрд╢реНрдХрд┐рд▓ рд╣реИред рдЗрд╕рдХреЗ рдЕрд▓рд╛рд╡рд╛, рд▓рдВрдмреЗ рд╕рдордп рддрдХ рдкреНрд░рд╡реЗрд╢ рдЗрд╕ рдореЙрдбрд▓ рдХреЛ рдЕрдиреБрдкрдпреБрдХреНрдд рдмрдирд╛рддрд╛ рд╣реИ рдЬрд╣рд╛рдВ рдкреНрд░рддрд┐рдХреНрд░рд┐рдпрд╛ рдХреА рдЧрддрд┐ рдорд╣рддреНрд╡рдкреВрд░реНрдг рд╣реЛрддреА рд╣реИред рдЗрд╕рд▓рд┐рдП, BERT рдХреЛ рддреЗрдЬ рдХрд░рдиреЗ рдХреЗ рддрд░реАрдХреЗ рдЦреЛрдЬрдирд╛ рдПрдХ рдмрд╣реБрдд рд╣реА рдЧрд░реНрдо рд╡рд┐рд╖рдп рд╣реИред
рд╣рдо Avito рдореЗрдВ рдЕрдХреНрд╕рд░ рдкрд╛рда рд╡рд░реНрдЧреАрдХрд░рдг рд╕рдорд╕реНрдпрд╛рдУрдВ рдХреЛ рд╣рд▓ рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП рд╣реИред рдпрд╣ рдПрдХ рд╕рд╛рдорд╛рдиреНрдп рд░реВрдк рд╕реЗ рд▓рд╛рдЧреВ рдорд╢реАрди рд▓рд░реНрдирд┐рдВрдЧ рдХрд╛рд░реНрдп рд╣реИ рдЬрд┐рд╕рдХрд╛ рдЕрдЪреНрдЫреА рддрд░рд╣ рд╕реЗ рдЕрдзреНрдпрдпрди рдХрд┐рдпрд╛ рдЧрдпрд╛ рд╣реИред рд▓реЗрдХрд┐рди рд╣рдореЗрд╢рд╛ рдХреБрдЫ рдирдпрд╛ рдХрд░рдиреЗ рдХреА рдХреЛрд╢рд┐рд╢ рдХрд░рдиреЗ рдХрд╛ рдкреНрд░рд▓реЛрднрди рд╣реЛрддрд╛ рд╣реИред рдпрд╣ рд▓реЗрдЦ рд░реЛрдЬрдорд░реНрд░рд╛ рдХреА рдорд╢реАрди рд╕реАрдЦрдиреЗ рдХреЗ рдХрд╛рд░реНрдпреЛрдВ рдореЗрдВ BERT рдХреЛ рд▓рд╛рдЧреВ рдХрд░рдиреЗ рдХреЗ рдкреНрд░рдпрд╛рд╕ рд╕реЗ рдкреИрджрд╛ рд╣реБрдЖ рдерд╛ред рдЗрд╕рдореЗрдВ, рдореИрдВ рджрд┐рдЦрд╛рдКрдВрдЧрд╛ рдХрд┐ рдХреИрд╕реЗ рдЖрдк рдирдП рдбреЗрдЯрд╛ рдХреЛ рдЬреЛрдбрд╝рдиреЗ рдФрд░ рдореЙрдбрд▓ рдХреЛ рдЬрдЯрд┐рд▓ рдХрд┐рдП рдмрд┐рдирд╛ BERT рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ рдореМрдЬреВрджрд╛ рдореЙрдбрд▓ рдХреА рдЧреБрдгрд╡рддреНрддрд╛ рдореЗрдВ рдХрд╛рдлреА рд╕реБрдзрд╛рд░ рдХрд░ рд╕рдХрддреЗ рд╣реИрдВред

рддрдВрддреНрд░рд┐рдХрд╛ рдиреЗрдЯрд╡рд░реНрдХ рдХреЛ рддреЗрдЬ рдХрд░рдиреЗ рдХреА рдПрдХ рд╡рд┐рдзрд┐ рдХреЗ рд░реВрдк рдореЗрдВ рдЬреНрдЮрд╛рди рдЖрд╕рд╡рди
рддрдВрддреНрд░рд┐рдХрд╛ рдиреЗрдЯрд╡рд░реНрдХ рдХреЛ рддреЗрдЬ / рд╣рд▓реНрдХрд╛ рдХрд░рдиреЗ рдХреЗ рдХрдИ рддрд░реАрдХреЗ рд╣реИрдВред рд╕рдмрд╕реЗ рд╡рд┐рд╕реНрддреГрдд рд╕рдореАрдХреНрд╖рд╛ рдЬреЛ рдореБрдЭреЗ рдорд┐рд▓реА рд╣реИ рд╡рд╣ рдордзреНрдпрдо рдкрд░ рдЗрдВрдЯреЗрдВрдЯреЛ рдмреНрд▓реЙрдЧ рдкрд░ рдкреНрд░рдХрд╛рд╢рд┐рдд рд╣реБрдИ рд╣реИред
рддрд░реАрдХреЛрдВ рдХреЛ рдореЛрдЯреЗ рддреМрд░ рдкрд░ рддреАрди рд╕рдореВрд╣реЛрдВ рдореЗрдВ рд╡рд┐рднрд╛рдЬрд┐рдд рдХрд┐рдпрд╛ рдЬрд╛ рд╕рдХрддрд╛ рд╣реИ:
- рдиреЗрдЯрд╡рд░реНрдХ рдЖрд░реНрдХрд┐рдЯреЗрдХреНрдЪрд░ рдмрджрд▓ рдЬрд╛рддрд╛ рд╣реИред
- рдореЙрдбрд▓ рд╕рдВрдкреАрдбрд╝рди (рдорд╛рддреНрд░рд╛ рдХрд╛ рдард╣рд░рд╛рд╡, рдЫрдВрдЯрд╛рдИ)ред
- рдЬреНрдЮрд╛рди рдЖрд╕рд╡рдиред
рдпрджрд┐ рдкрд╣рд▓реЗ рджреЛ рддрд░реАрдХреЗ рдЕрдкреЗрдХреНрд╖рд╛рдХреГрдд рдкреНрд░рд╕рд┐рджреНрдз рдФрд░ рд╕рдордЭрдиреЗ рдпреЛрдЧреНрдп рд╣реИрдВ, рддреЛ рддреАрд╕рд░рд╛ рдХрдо рдЖрдо рд╣реИред рдкрд╣рд▓реА рдмрд╛рд░, рдЖрд╕рд╡рди рдХрд╛ рд╡рд┐рдЪрд╛рд░ рд░рд┐рдЪ рдХрд╛рд░реБрдЖрдирд╛ рджреНрд╡рд╛рд░рд╛ "рдореЙрдбрд▓ рд╕рдВрдкреАрдбрд╝рди" рд▓реЗрдЦ рдореЗрдВ рдкреНрд░рд╕реНрддрд╛рд╡рд┐рдд рдХрд┐рдпрд╛ рдЧрдпрд╛ рдерд╛ред рдЗрд╕рдХрд╛ рд╕рд╛рд░ рд╕рд░рд▓ рд╣реИ: рдЖрдк рдПрдХ рд╣рд▓реНрдХреЗ рдореЙрдбрд▓ рдХреЛ рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдХрд░ рд╕рдХрддреЗ рд╣реИрдВ рдЬреЛ рдПрдХ рд╢рд┐рдХреНрд╖рдХ рдореЙрдбрд▓ рдХреЗ рд╡реНрдпрд╡рд╣рд╛рд░ рдпрд╛ рдпрд╣рд╛рдВ рддрдХ тАЛтАЛрдХрд┐ рдореЙрдбрд▓ рдХреА рдПрдХ рдЯреБрдХрдбрд╝реА рдХреА рдирдХрд▓ рдХрд░реЗрдЧрд╛ред рд╣рдорд╛рд░реЗ рдорд╛рдорд▓реЗ рдореЗрдВ, рд╢рд┐рдХреНрд╖рдХ рдмреАрдИрдЖрд░рдЯреА рд╣реЛрдЧрд╛, рдФрд░ рдЫрд╛рддреНрд░ рдХреЛрдИ рднреА рдкреНрд░рдХрд╛рд╢ рдореЙрдбрд▓ рд╣реЛрдЧрд╛ред
рдХрд╛рд░реНрдп
рдЖрдЗрдП рдПрдХ рдЙрджрд╛рд╣рд░рдг рдХреЗ рд░реВрдк рдореЗрдВ рджреНрд╡рд┐рдЖрдзрд╛рд░реА рд╡рд░реНрдЧреАрдХрд░рдг рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ рдЖрд╕рд╡рди рдХрд╛ рд╡рд┐рд╢реНрд▓реЗрд╖рдг рдХрд░реЗрдВред рдПрдирдПрд▓рдкреА рдХреЗ рд▓рд┐рдП рдореЙрдбрд▓ рдХрд╛ рдкрд░реАрдХреНрд╖рдг рдХрд░рдиреЗ рд╡рд╛рд▓реЗ рдХрд╛рд░реНрдпреЛрдВ рдХреЗ рдорд╛рдирдХ рд╕реЗрдЯ рд╕реЗ рдУрдкрди рдПрд╕рдПрд╕рдЯреА -2 рдбреЗрдЯрд╛рд╕реЗрдЯ рд▓реЗрдВред
рдпрд╣ рдбреЗрдЯрд╛рд╕реЗрдЯ рднрд╛рд╡рдирд╛рддреНрдордХ рд░рдВрдЧ - рд╕рдХрд╛рд░рд╛рддреНрдордХ рдпрд╛ рдирдХрд╛рд░рд╛рддреНрдордХ рджреНрд╡рд╛рд░рд╛ рдЯреВрдЯреА рд╣реБрдИ IMDb рдХреЗ рд╕рд╛рде рдлрд┐рд▓реНрдореЛрдВ рдХреА рд╕рдореАрдХреНрд╖рд╛рдУрдВ рдХрд╛ рдПрдХ рд╕рдВрдЧреНрд░рд╣ рд╣реИред рдЗрд╕ рдбреЗрдЯрд╛рд╕реЗрдЯ рдкрд░ рдореАрдЯреНрд░рд┐рдХ рд╕рдЯреАрдХрддрд╛ рд╣реИред
рдкреНрд░рд╢рд┐рдХреНрд╖рдг BERT- рдЖрдзрд╛рд░рд┐рдд рдореЙрдбрд▓ рдпрд╛ "рд╢рд┐рдХреНрд╖рдХ"
рд╕рдмрд╕реЗ рдкрд╣рд▓реЗ, рдЖрдкрдХреЛ "рдмрдбрд╝реЗ" BERT- рдЖрдзрд╛рд░рд┐рдд рдореЙрдбрд▓ рдХреЛ рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдХрд░рдиреЗ рдХреА рдЖрд╡рд╢реНрдпрдХрддрд╛ рд╣реИ, рдЬреЛ рдПрдХ рд╢рд┐рдХреНрд╖рдХ рдмрди рдЬрд╛рдПрдЧрд╛ред рдРрд╕рд╛ рдХрд░рдиреЗ рдХрд╛ рд╕рдмрд╕реЗ рдЖрд╕рд╛рди рддрд░реАрдХрд╛ рд╣реИ рдХрд┐ BERT рд╕реЗ рдПрдореНрдмреЗрдбрд┐рдВрдЧ рд▓реЗрдирд╛ рдФрд░ рдЙрдирдХреЗ рдКрдкрд░ рдХреНрд▓рд╛рд╕рд┐рдлрд╛рдпрд░ рдХреЛ рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдХрд░рдирд╛, рдПрдХ рдкрд░рдд рдХреЛ рдиреЗрдЯрд╡рд░реНрдХ рдореЗрдВ рдЬреЛрдбрд╝рдирд╛ред
рдЯреНрд░рд╛рдВрд╕рдлреЙрд░реНрдорд░ рд▓рд╛рдЗрдмреНрд░реЗрд░реА рдХреЗ рд▓рд┐рдП рдзрдиреНрдпрд╡рд╛рдж , рдпрд╣ рдХрд░рдирд╛ рдмрд╣реБрдд рдЖрд╕рд╛рди рд╣реИ, рдХреНрдпреЛрдВрдХрд┐ рдмрд░реНрдЯрдлреЙрд░рд╕реЗрдВрд╕реЗрдВрд╕ рдХреНрд▓реИрд╕рд┐рдлрд┐рдХреЗрд╢рди рдореЙрдбрд▓ рдХреЗ рд▓рд┐рдП рдПрдХ рддреИрдпрд╛рд░ рд╡рд░реНрдЧ рд╣реИред рдореЗрд░реА рд░рд╛рдп рдореЗрдВ, рдЗрд╕ рдореЙрдбрд▓ рдХреЛ рдкрдврд╝рд╛рдиреЗ рдХреЗ рд▓рд┐рдП рд╕рдмрд╕реЗ рд╡рд┐рд╕реНрддреГрдд рдФрд░ рд╕рдордЭрдиреЗ рдпреЛрдЧреНрдп рдЯреНрдпреВрдЯреЛрд░рд┐рдпрд▓ рдЯреБрд╡рд░реНрдбреНрд╕ рдбреЗрдЯрд╛ рд╕рд╛рдЗрдВрд╕ рджреНрд╡рд╛рд░рд╛ рдкреНрд░рдХрд╛рд╢рд┐рдд рдХрд┐рдпрд╛ рдЧрдпрд╛ рдерд╛ред
рдЖрдЗрдП рдХрд▓реНрдкрдирд╛ рдХрд░реЗрдВ рдХрд┐ рд╣рдореЗрдВ рдПрдХ рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд BertForSequenceClassification рдореЙрдбрд▓ рдорд┐рд▓рд╛ред рд╣рдорд╛рд░реЗ рдорд╛рдорд▓реЗ рдореЗрдВ, num_labels = 2, рдХреНрдпреЛрдВрдХрд┐ рд╣рдорд╛рд░реЗ рдкрд╛рд╕ рдПрдХ рджреНрд╡рд┐рдЖрдзрд╛рд░реА рд╡рд░реНрдЧреАрдХрд░рдг рд╣реИред рд╣рдо рдЗрд╕ рдореЙрдбрд▓ рдХрд╛ рдЙрдкрдпреЛрдЧ "рд╢рд┐рдХреНрд╖рдХ" рдХреЗ рд░реВрдк рдореЗрдВ рдХрд░реЗрдВрдЧреЗред
"рдЫрд╛рддреНрд░" рд╕реАрдЦрдирд╛
рдЖрдк рдПрдХ рдЫрд╛рддреНрд░ рдХреЗ рд░реВрдк рдореЗрдВ рдХрд┐рд╕реА рднреА рд╡рд╛рд╕реНрддреБрдХрд▓рд╛ рдХреЛ рд▓реЗ рд╕рдХрддреЗ рд╣реИрдВ: рдПрдХ рддрдВрддреНрд░рд┐рдХрд╛ рдиреЗрдЯрд╡рд░реНрдХ, рдПрдХ рд░реИрдЦрд┐рдХ рдореЙрдбрд▓, рдПрдХ рдирд┐рд░реНрдгрдп рд╡реГрдХреНрд╖ред рдЖрдЗрдП рдмреЗрд╣рддрд░ рджреГрд╢реНрдп рдХреЗ рд▓рд┐рдП BiLSTM рдХреЛ рд╕рд┐рдЦрд╛рдиреЗ рдХрд╛ рдкреНрд░рдпрд╛рд╕ рдХрд░реЗрдВред рд╢реБрд░реВ рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП, рд╣рдо рдмрд┐рдирд╛ BERT рдХреЗ BiLSTM рд╕рд┐рдЦрд╛рдПрдВрдЧреЗред
рдПрдХ рддрдВрддреНрд░рд┐рдХрд╛ рдиреЗрдЯрд╡рд░реНрдХ рдХреЗ рдЗрдирдкреБрдЯ рдХреЛ рдкрд╛рда рдкреНрд░рд╕реНрддреБрдд рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП, рдЖрдкрдХреЛ рдЗрд╕реЗ рд╡реЗрдХреНрдЯрд░ рдХреЗ рд░реВрдк рдореЗрдВ рдкреНрд░рд╕реНрддреБрдд рдХрд░рдиреЗ рдХреА рдЖрд╡рд╢реНрдпрдХрддрд╛ рд╣реИред рд╕рдмрд╕реЗ рдЖрд╕рд╛рди рддрд░реАрдХреЛрдВ рдореЗрдВ рд╕реЗ рдПрдХ рд╢рдмреНрджрдХреЛрд╢ рдореЗрдВ рдкреНрд░рддреНрдпреЗрдХ рд╢рдмреНрдж рдХреЛ рдЙрд╕рдХреЗ рд╕реВрдЪрдХрд╛рдВрдХ рдореЗрдВ рдореИрдк рдХрд░рдирд╛ рд╣реИред рд╢рдмреНрджрдХреЛрд╢ рдореЗрдВ рд╣рдорд╛рд░реЗ рдбрд╛рдЯрд╛рд╕реЗрдЯ рдкреНрд▓рд╕ рджреЛ рд╕реЗрд╡рд╛ рд╢рдмреНрджреЛрдВ рдореЗрдВ рд╢реАрд░реНрд╖-рдПрди рд╕рдмрд╕реЗ рд▓реЛрдХрдкреНрд░рд┐рдп рд╢рдмреНрдж рд╢рд╛рдорд┐рд▓ рд╣реЛрдВрдЧреЗ: "рдкреИрдб" - "рдбрдореА рд╢рдмреНрдж" рддрд╛рдХрд┐ рд╕рднреА рдЕрдиреБрдХреНрд░рдо рдПрдХ рд╣реА рд▓рдВрдмрд╛рдИ рдХреЗ рд╣реЛрдВ, рдФрд░ рд╢рдмреНрджрдХреЛрд╢ рдХреЗ рдмрд╛рд╣рд░ рдХреЗ рд╢рдмреНрджреЛрдВ рдХреЗ рд▓рд┐рдП "рдЕрдирдХ"ред рд╣рдо рдорд╢рд╛рд▓ рдХреЗ рдЙрдкрдХрд░рдг рдХреЗ рдорд╛рдирдХ рд╕реЗрдЯ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ рд╢рдмреНрджрдХреЛрд╢ рдХрд╛ рдирд┐рд░реНрдорд╛рдг рдХрд░реЗрдВрдЧреЗред рд╕рд╛рджрдЧреА рдХреЗ рд▓рд┐рдП, рдореИрдВрдиреЗ рдкреВрд░реНрд╡-рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рд╢рдмреНрдж рдПрдореНрдмреЗрдбрд┐рдВрдЧ рдХрд╛ рдЙрдкрдпреЛрдЧ рдирд╣реАрдВ рдХрд┐рдпрд╛ред
import torch from torchtext import data def get_vocab(X): X_split = [t.split() for t in X] text_field = data.Field() text_field.build_vocab(X_split, max_size=10000) return text_field def pad(seq, max_len): if len(seq) < max_len: seq = seq + ['<pad>'] * (max_len - len(seq)) return seq[0:max_len] def to_indexes(vocab, words): return [vocab.stoi[w] for w in words] def to_dataset(x, y, y_real): torch_x = torch.tensor(x, dtype=torch.long) torch_y = torch.tensor(y, dtype=torch.float) torch_real_y = torch.tensor(y_real, dtype=torch.long) return TensorDataset(torch_x, torch_y, torch_real_y)
рдореЙрдбрд▓ BiLSTM
рдореЙрдбрд▓ рдХреЗ рд▓рд┐рдП рдХреЛрдб рдЗрд╕ рддрд░рд╣ рджрд┐рдЦреЗрдЧрд╛:
import torch from torch import nn from torch.autograd import Variable class SimpleLSTM(nn.Module): def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional, dropout, batch_size, device=None): super(SimpleLSTM, self).__init__() self.batch_size = batch_size self.hidden_dim = hidden_dim self.n_layers = n_layers self.embedding = nn.Embedding(input_dim, embedding_dim) self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers, bidirectional=bidirectional, dropout=dropout) self.fc = nn.Linear(hidden_dim * 2, output_dim) self.dropout = nn.Dropout(dropout) self.device = self.init_device(device) self.hidden = self.init_hidden() @staticmethod def init_device(device): if device is None: return torch.device('cuda') return device def init_hidden(self): return (Variable(torch.zeros(2 * self.n_layers, self.batch_size, self.hidden_dim).to(self.device)), Variable(torch.zeros(2 * self.n_layers, self.batch_size, self.hidden_dim).to(self.device))) def forward(self, text, text_lengths=None): self.hidden = self.init_hidden() x = self.embedding(text) x, self.hidden = self.rnn(x, self.hidden) hidden, cell = self.hidden hidden = self.dropout(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)) x = self.fc(hidden) return x
рдЯреНрд░реЗрдирд┐рдВрдЧ
рдЗрд╕ рдореЙрдбрд▓ рдХреЗ рд▓рд┐рдП, рдЖрдЙрдЯрдкреБрдЯ рд╡реЗрдХреНрдЯрд░ рдХрд╛ рдЖрдпрд╛рдо (batch_size, output_dim) рд╣реЛрдЧрд╛ред рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдореЗрдВ, рд╣рдо рд╕рд╛рдорд╛рдиреНрдп рд▓реЙрдЧрд▓реЙрд╕ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░реЗрдВрдЧреЗред PyTorch рдореЗрдВ BCEWithLogitsLoss рдХреНрд▓рд╛рд╕ рд╣реИ рдЬреЛ рд╕рд┐рдЧреНрдореЙрдЗрдб рдФрд░ рдХреНрд░реЙрд╕ рдПрдиреНрдЯреНрд░реЙрдкреА рдХреЛ рдЬреЛрдбрд╝рддреА рд╣реИред рдЖрдкрдХреЛ рдХреНрдпрд╛ рдЪрд╛рд╣рд┐рдП
def loss(self, output, bert_prob, real_label): criterion = torch.nn.BCEWithLogitsLoss() return criterion(output, real_label.float())
рд╕реАрдЦрдиреЗ рдХреЗ рдПрдХ рдпреБрдЧ рдХреЗ рд▓рд┐рдП рдХреЛрдб:
def get_optimizer(model): optimizer = torch.optim.Adam(model.parameters()) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 2, gamma=0.9) return optimizer, scheduler def epoch_train_func(model, dataset, loss_func, batch_size): train_loss = 0 train_sampler = RandomSampler(dataset) data_loader = DataLoader(dataset, sampler=train_sampler, batch_size=batch_size, drop_last=True) model.train() optimizer, scheduler = get_optimizer(model) for i, (text, bert_prob, real_label) in enumerate(tqdm(data_loader, desc='Train')): text, bert_prob, real_label = to_device(text, bert_prob, real_label) model.zero_grad() output = model(text.t(), None).squeeze(1) loss = loss_func(output, bert_prob, real_label) loss.backward() optimizer.step() train_loss += loss.item() scheduler.step() return train_loss / len(data_loader)
рдпреБрдЧ рдХреЗ рдмрд╛рдж рд╕рддреНрдпрд╛рдкрди рдХреЗ рд▓рд┐рдП рдХреЛрдб:
def epoch_evaluate_func(model, eval_dataset, loss_func, batch_size): eval_sampler = SequentialSampler(eval_dataset) data_loader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=batch_size, drop_last=True) eval_loss = 0.0 model.eval() for i, (text, bert_prob, real_label) in enumerate(tqdm(data_loader, desc='Val')): text, bert_prob, real_label = to_device(text, bert_prob, real_label) output = model(text.t(), None).squeeze(1) loss = loss_func(output, bert_prob, real_label) eval_loss += loss.item() return eval_loss / len(data_loader)
рдпрджрд┐ рдпрд╣ рд╕рдм рдПрдХ рд╕рд╛рде рд░рдЦрд╛ рдЬрд╛рддрд╛ рд╣реИ, рддреЛ рд╣рдореЗрдВ рдореЙрдбрд▓ рдХреЗ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЗ рд▓рд┐рдП рдирд┐рдореНрдирд▓рд┐рдЦрд┐рдд рдХреЛрдб рдорд┐рд▓рддреЗ рд╣реИрдВ:
import os import torch from torch.utils.data import (TensorDataset, random_split, RandomSampler, DataLoader, SequentialSampler) from torchtext import data from tqdm import tqdm def device(): return torch.device("cuda" if torch.cuda.is_available() else "cpu") def to_device(text, bert_prob, real_label): text = text.to(device()) bert_prob = bert_prob.to(device()) real_label = real_label.to(device()) return text, bert_prob, real_label class LSTMBaseline(object): vocab_name = 'text_vocab.pt' weights_name = 'simple_lstm.pt' def __init__(self, settings): self.settings = settings self.criterion = torch.nn.BCEWithLogitsLoss().to(device()) def loss(self, output, bert_prob, real_label): return self.criterion(output, real_label.float()) def model(self, text_field): model = SimpleLSTM( input_dim=len(text_field.vocab), embedding_dim=64, hidden_dim=128, output_dim=1, n_layers=1, bidirectional=True, dropout=0.5, batch_size=self.settings['train_batch_size']) return model def train(self, X, y, y_real, output_dir): max_len = self.settings['max_seq_length'] text_field = get_vocab(X) X_split = [t.split() for t in X] X_pad = [pad(s, max_len) for s in tqdm(X_split, desc='pad')] X_index = [to_indexes(text_field.vocab, s) for s in tqdm(X_pad, desc='to index')] dataset = to_dataset(X_index, y, y_real) val_len = int(len(dataset) * 0.1) train_dataset, val_dataset = random_split(dataset, (len(dataset) - val_len, val_len)) model = self.model(text_field) model.to(device()) self.full_train(model, train_dataset, val_dataset, output_dir) torch.save(text_field, os.path.join(output_dir, self.vocab_name)) def full_train(self, model, train_dataset, val_dataset, output_dir): train_settings = self.settings num_train_epochs = train_settings['num_train_epochs'] best_eval_loss = 100000 for epoch in range(num_train_epochs): train_loss = epoch_train_func(model, train_dataset, self.loss, self.settings['train_batch_size']) eval_loss = epoch_evaluate_func(model, val_dataset, self.loss, self.settings['eval_batch_size']) if eval_loss < best_eval_loss: best_eval_loss = eval_loss torch.save(model.state_dict(), os.path.join(output_dir, self.weights_name))
рдЖрд╕рд╡рди
рдЗрд╕ рдЖрд╕рд╡рди рд╡рд┐рдзрд┐ рдХрд╛ рд╡рд┐рдЪрд╛рд░ рд╡рд╛рдЯрд░рд▓реВ рд╡рд┐рд╢реНрд╡рд╡рд┐рджреНрдпрд╛рд▓рдп рдХреЗ рд╢реЛрдзрдХрд░реНрддрд╛рдУрдВ рджреНрд╡рд╛рд░рд╛ рдПрдХ рд▓реЗрдЦ рд╕реЗ рд▓рд┐рдпрд╛ рдЧрдпрд╛ рд╣реИ ред рдЬреИрд╕рд╛ рдХрд┐ рдореИрдВрдиреЗ рдКрдкрд░ рдХрд╣рд╛, "рдЫрд╛рддреНрд░" рдХреЛ "рд╢рд┐рдХреНрд╖рдХ" рдХреЗ рд╡реНрдпрд╡рд╣рд╛рд░ рдХреА рдирдХрд▓ рдХрд░рдирд╛ рд╕реАрдЦрдирд╛ рдЪрд╛рд╣рд┐рдПред рд╡рд╛рд╕реНрддрд╡ рдореЗрдВ рд╡реНрдпрд╡рд╣рд╛рд░ рдХреНрдпрд╛ рд╣реИ? рд╣рдорд╛рд░реЗ рдорд╛рдорд▓реЗ рдореЗрдВ, рдпреЗ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рд╕реЗрдЯ рдкрд░ рд╢рд┐рдХреНрд╖рдХ рдореЙрдбрд▓ рдХреА рднрд╡рд┐рд╖реНрдпрд╡рд╛рдгрд┐рдпрд╛рдВ рд╣реИрдВред рдФрд░ рдореБрдЦреНрдп рд╡рд┐рдЪрд╛рд░ рд╕рдХреНрд░рд┐рдпрдг рдлрд╝рдВрдХреНрд╢рди рдХреЛ рд▓рд╛рдЧреВ рдХрд░рдиреЗ рд╕реЗ рдкрд╣рд▓реЗ рдиреЗрдЯрд╡рд░реНрдХ рдЖрдЙрдЯрдкреБрдЯ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдирд╛ рд╣реИред рдпрд╣ рдорд╛рдирд╛ рдЬрд╛рддрд╛ рд╣реИ рдХрд┐ рдЗрд╕ рддрд░рд╣ рд╕реЗ рдореЙрдбрд▓ рдЕрдВрддрд┐рдо рд╕рдВрднрд╛рд╡рдирд╛рдУрдВ рдХреЗ рдорд╛рдорд▓реЗ рдореЗрдВ рдЖрдВрддрд░рд┐рдХ рдкреНрд░рддрд┐рдирд┐рдзрд┐рддреНрд╡ рдХреЛ рдмреЗрд╣рддрд░ рдврдВрдЧ рд╕реЗ рд╕реАрдЦ рд╕рдХреЗрдЧрд╛ред
рдореВрд▓ рд▓реЗрдЦ рдореЗрдВ рд╣рд╛рдирд┐ рдлрд╝рдВрдХреНрд╢рди рдХреЗ рд▓рд┐рдП рдПрдХ рд╢рдмреНрдж рдЬреЛрдбрд╝рдиреЗ рдХрд╛ рдкреНрд░рд╕реНрддрд╛рд╡ рд╣реИ, рдЬреЛ "рд▓реЙрдЧ" рддреНрд░реБрдЯрд┐ рдХреЗ рд▓рд┐рдП рдЬрд┐рдореНрдореЗрджрд╛рд░ рд╣реЛрдЧрд╛ - рдореЙрдбрд▓ рд▓реЙрдЧ рдХреЗ рдмреАрдЪ рдПрдордПрд╕рдИред

рдЗрди рдЙрджреНрджреЗрд╢реНрдпреЛрдВ рдХреЗ рд▓рд┐рдП, рд╣рдо рджреЛ рдЫреЛрдЯреЗ рдмрджрд▓рд╛рд╡ рдХрд░рддреЗ рд╣реИрдВ: 1 рд╕реЗ 2 рддрдХ рдиреЗрдЯрд╡рд░реНрдХ рдЖрдЙрдЯрдкреБрдЯ рдХреА рд╕рдВрдЦреНрдпрд╛ рдХреЛ рдмрджрд▓реЗрдВ рдФрд░ рдиреБрдХрд╕рд╛рди рдлрд╝рдВрдХреНрд╢рди рдХреЛ рдареАрдХ рдХрд░реЗрдВред
def loss(self, output, bert_prob, real_label): a = 0.5 criterion_mse = torch.nn.MSELoss() criterion_ce = torch.nn.CrossEntropyLoss() return a*criterion_ce(output, real_label) + (1-a)*criterion_mse(output, bert_prob)
рдЖрдк рдХреЗрд╡рд▓ рдореЙрдбрд▓ рдФрд░ рд╣рд╛рдирд┐ рдХреЛ рдлрд┐рд░ рд╕реЗ рдкрд░рд┐рднрд╛рд╖рд┐рдд рдХрд░рдХреЗ рд▓рд┐рдЦреЗ рдЧрдП рд╕рднреА рдХреЛрдб рдХрд╛ рдкреБрди: рдЙрдкрдпреЛрдЧ рдХрд░ рд╕рдХрддреЗ рд╣реИрдВ:
class LSTMDistilled(LSTMBaseline): vocab_name = 'distil_text_vocab.pt' weights_name = 'distil_lstm.pt' def __init__(self, settings): super(LSTMDistilled, self).__init__(settings) self.criterion_mse = torch.nn.MSELoss() self.criterion_ce = torch.nn.CrossEntropyLoss() self.a = 0.5 def loss(self, output, bert_prob, real_label): return self.a * self.criterion_ce(output, real_label) + (1 - self.a) * self.criterion_mse(output, bert_prob) def model(self, text_field): model = SimpleLSTM( input_dim=len(text_field.vocab), embedding_dim=64, hidden_dim=128, output_dim=2, n_layers=1, bidirectional=True, dropout=0.5, batch_size=self.settings['train_batch_size']) return model
рдмрд╕ рдЗрддрдирд╛ рд╣реА, рдЕрдм рд╣рдорд╛рд░рд╛ рдореЙрдбрд▓ "рдирдХрд▓" рдХрд░рдирд╛ рд╕реАрдЦ рд░рд╣рд╛ рд╣реИред
рдореЙрдбрд▓ рддреБрд▓рдирд╛
рдореВрд▓ рд▓реЗрдЦ рдореЗрдВ, SST-2 рдХреЗ рд▓рд┐рдП рд╕рд░реНрд╡рд╢реНрд░реЗрд╖реНрда рд╡рд░реНрдЧреАрдХрд░рдг рдкрд░рд┐рдгрд╛рдо = 0 рдкрд░ рдкреНрд░рд╛рдкреНрдд рдХрд┐рдП рдЬрд╛рддреЗ рд╣реИрдВ, рдЬрдм рдореЙрдбрд▓ рдХреЗрд╡рд▓ рдирдХрд▓ рдХрд░рдирд╛ рд╕реАрдЦрддрд╛ рд╣реИ, рд╡рд╛рд╕реНрддрд╡рд┐рдХ рд▓реЗрдмрд▓ рдХреЛ рдзреНрдпрд╛рди рдореЗрдВ рдирд╣реАрдВ рд░рдЦрддрд╛ рд╣реИред рд╕рдЯреАрдХрддрд╛ рдЕрднреА рднреА BERT рд╕реЗ рдХрдо рд╣реИ, рд▓реЗрдХрд┐рди рдирд┐рдпрдорд┐рдд BiLSTM рд╕реЗ рдХрд╛рдлреА рдмреЗрд╣рддрд░ рд╣реИред

рдореИрдВрдиреЗ рд▓реЗрдЦ рд╕реЗ рдкрд░рд┐рдгрд╛рдо рджреЛрд╣рд░рд╛рдиреЗ рдХреА рдХреЛрд╢рд┐рд╢ рдХреА, рд▓реЗрдХрд┐рди рдореЗрд░реЗ рдкреНрд░рдпреЛрдЧреЛрдВ рдореЗрдВ рд╕рд░реНрд╡реЛрддреНрддрдо рдкрд░рд┐рдгрд╛рдо = 0.5 рдкрд░ рдкреНрд░рд╛рдкреНрдд рд╣реБрдЖред
рд╕рд╛рдорд╛рдиреНрдп рддрд░реАрдХреЗ рд╕реЗ LSTM рд╕реАрдЦрддреЗ рд╕рдордп рдпрд╣ рд╣рд╛рдирд┐ рдФрд░ рд╕рдЯреАрдХрддрд╛ рд░реЗрдЦрд╛рдВрдХрди рджрд┐рдЦрддрд╛ рд╣реИред рдиреБрдХрд╕рд╛рди рдХреЗ рд╡реНрдпрд╡рд╣рд╛рд░ рдХреЛ рджреЗрдЦрддреЗ рд╣реБрдП, рдореЙрдбрд▓ рдиреЗ рддреЗрдЬреА рд╕реЗ рд╕реАрдЦрд╛, рдФрд░ рдХрд╣реАрдВ рди рдХрд╣реАрдВ рдЫрдареЗ рдпреБрдЧ рдХреЗ рдмрд╛рдж, рдлрд┐рд░ рд╕реЗ рджреЗрдЦрдирд╛ рд╢реБрд░реВ рд╣реБрдЖред

рдЖрд╕рд╡рди рд░реЗрдЦрд╛рдВрдХрди:

рдбрд┐рд╕реНрдЯрд┐рд▓реНрдб BiLSTM рд╕рд╛рдорд╛рдиреНрдп рд╕реЗ рд▓рдЧрд╛рддрд╛рд░ рдмреЗрд╣рддрд░ рд╣реИред рдпрд╣ рдорд╣рддреНрд╡рдкреВрд░реНрдг рд╣реИ рдХрд┐ рд╡реЗ рд╡рд╛рд╕реНрддреБрдХрд▓рд╛ рдореЗрдВ рдмрд┐рд▓реНрдХреБрд▓ рд╕рдорд╛рди рд╣реИрдВ, рдПрдХрдорд╛рддреНрд░ рдЕрдВрддрд░ рд╢рд┐рдХреНрд╖рдг рдХреЗ рддрд░реАрдХреЗ рдореЗрдВ рд╣реИред рдореИрдВрдиреЗ GitHub рдкрд░ рдкреВрд░реНрдг рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЛрдб рдкреЛрд╕реНрдЯ рдХрд┐рдпрд╛ ред
рдирд┐рд╖реНрдХрд░реНрд╖
рдЗрд╕ рдЧрд╛рдЗрдб рдореЗрдВ, рдореИрдВрдиреЗ рдЖрд╕рд╡рди рджреГрд╖реНрдЯрд┐рдХреЛрдг рдХреЗ рдореВрд▓ рд╡рд┐рдЪрд╛рд░ рдХреЛ рд╕рдордЭрд╛рдиреЗ рдХреА рдХреЛрд╢рд┐рд╢ рдХреАред рдЫрд╛рддреНрд░ рдХреА рд╡рд┐рд╢рд┐рд╖реНрдЯ рд╡рд╛рд╕реНрддреБрдХрд▓рд╛ рд╣рд╛рде рдореЗрдВ рдХрд╛рдо рдкрд░ рдирд┐рд░реНрднрд░ рдХрд░реЗрдЧреАред рд▓реЗрдХрд┐рди рд╕рд╛рдорд╛рдиреНрдп рддреМрд░ рдкрд░, рдпрд╣ рджреГрд╖реНрдЯрд┐рдХреЛрдг рдХрд┐рд╕реА рднреА рд╡реНрдпрд╛рд╡рд╣рд╛рд░рд┐рдХ рдХрд╛рд░реНрдп рдореЗрдВ рд▓рд╛рдЧреВ рд╣реЛрддрд╛ рд╣реИред рдореЙрдбрд▓ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЗ рдЪрд░рдг рдореЗрдВ рдЬрдЯрд┐рд▓рддрд╛ рдХреЗ рдХрд╛рд░рдг, рдЖрдк рд╡рд╛рд╕реНрддреБрдХрд▓рд╛ рдХреА рдореВрд▓ рд╕рд╛рджрдЧреА рдХреЛ рдмрдирд╛рдП рд░рдЦрддреЗ рд╣реБрдП, рдЗрд╕рдХреА рдЧреБрдгрд╡рддреНрддрд╛ рдореЗрдВ рдЙрд▓реНрд▓реЗрдЦрдиреАрдп рд╡реГрджреНрдзрд┐ рдкреНрд░рд╛рдкреНрдд рдХрд░ рд╕рдХрддреЗ рд╣реИрдВред