рд▓реЗрдЦ рд░реВрд╕реА рдореЗрдВ рдкрд╛рда рд╕рдВрджреЗрд╢реЛрдВ рдХреЗ рдЯрди рдХреА рд╡рд░реНрдЧреАрдХрд░рдг рдкрд░ рдЪрд░реНрдЪрд╛ рдХрд░реЗрдЧрд╛ (рдФрд░ рдЕрдирд┐рд╡рд╛рд░реНрдп рд░реВрдк рд╕реЗ рдЙрд╕реА рддрдХрдиреАрдХ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ рдЧреНрд░рдВрдереЛрдВ рдХреЗ рдХрд┐рд╕реА рднреА рд╡рд░реНрдЧреАрдХрд░рдг)ред рд╣рдо
рдЗрд╕ рд▓реЗрдЦ рдХреЛ рдПрдХ рдЖрдзрд╛рд░ рдХреЗ рд░реВрдк рдореЗрдВ рд▓реЗрдВрдЧреЗ, рдЬрд┐рд╕рдореЗрдВ Word2vec рдореЙрдбрд▓ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ CNN рдЖрд░реНрдХрд┐рдЯреЗрдХреНрдЪрд░ рдкрд░ рдЖрдЬ рдХреА рд░рд╛рдд рдХреЗ рд╡рд░реНрдЧреАрдХрд░рдг рдкрд░ рд╡рд┐рдЪрд╛рд░ рдХрд┐рдпрд╛ рдЧрдпрд╛ рдерд╛ред рд╣рдорд╛рд░реЗ рдЙрджрд╛рд╣рд░рдг рдореЗрдВ, рд╣рдо
ULMFit рдореЙрдбрд▓ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ рдПрдХ рд╣реА рдбреЗрдЯрд╛рд╕реЗрдЯ рдкрд░ рд╕рдХрд╛рд░рд╛рддреНрдордХ рдФрд░ рдирдХрд╛рд░рд╛рддреНрдордХ рдореЗрдВ рдЯреНрд╡реАрдЯреНрд╕ рдХреЛ рдЕрд▓рдЧ рдХрд░рдиреЗ рдХреА рдПрдХ рд╣реА рд╕рдорд╕реНрдпрд╛ рдХреЛ рд╣рд▓ рдХрд░реЗрдВрдЧреЗред рд▓реЗрдЦ рд╕реЗ рдкрд░рд┐рдгрд╛рдо (рдФрд╕рдд рдПрдл 1-рд╕реНрдХреЛрд░ = 0.78142) рдХреЛ рдЖрдзрд╛рд░ рд░реЗрдЦрд╛ рдХреЗ рд░реВрдк рдореЗрдВ рд╕реНрд╡реАрдХрд╛рд░ рдХрд┐рдпрд╛ рдЬрд╛рдПрдЧрд╛ред
рдкрд░рд┐рдЪрдп
ULMFIT рдореЙрдбрд▓ рдХреЛ 2018 рдореЗрдВ fast.ai Developers (рдЬреЗрд░реЗрдореА рд╣реЙрд╡рд░реНрдб, рд╕реЗрдмреЗрд╕реНрдЯрд┐рдпрди рд░реБрдбрд░) рджреНрд╡рд╛рд░рд╛ рдкреЗрд╢ рдХрд┐рдпрд╛ рдЧрдпрд╛ рдерд╛ред рджреГрд╖реНрдЯрд┐рдХреЛрдг рдХрд╛ рд╕рд╛рд░ рдПрдирдПрд▓рдкреА рдХрд╛рд░реНрдпреЛрдВ рдореЗрдВ рдЯреНрд░рд╛рдВрд╕рдлрд░ рд▓рд░реНрдирд┐рдВрдЧ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдирд╛ рд╣реИ рдЬрдм рдЖрдк рдкреВрд░реНрд╡-рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдореЙрдбрд▓ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рддреЗ рд╣реИрдВ, рдЕрдкрдиреЗ рдореЙрдбрд▓ рдХреЗ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЗ рд▓рд┐рдП рд╕рдордп рдХрдо рдХрд░рддреЗ рд╣реИрдВ рдФрд░ рд▓реЗрдмрд▓ рдХрд┐рдП рдЧрдП рдкрд░реАрдХреНрд╖рдг рдирдореВрдиреЗ рдХреЗ рдЖрдХрд╛рд░ рдХреЗ рд▓рд┐рдП рдЖрд╡рд╢реНрдпрдХрддрд╛рдУрдВ рдХреЛ рдХрдо рдХрд░рддреЗ рд╣реИрдВред
рд╣рдорд╛рд░реЗ рдорд╛рдорд▓реЗ рдореЗрдВ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдпреЛрдЬрдирд╛ рдЗрд╕ рдкреНрд░рдХрд╛рд░ рд╣реЛрдЧреА:

рднрд╛рд╖рд╛ рдореЙрдбрд▓ рдХрд╛ рдЕрд░реНрде рдЕрдиреБрдХреНрд░рдо рдореЗрдВ рдЕрдЧрд▓реЗ рд╢рдмреНрдж рдХреА рднрд╡рд┐рд╖реНрдпрд╡рд╛рдгреА рдХрд░рдиреЗ рдореЗрдВ рд╕рдХреНрд╖рдо рд╣реЛрдирд╛ рд╣реИред рдЗрд╕ рддрд░рд╣ рд╕реЗ рд▓рдВрдмреЗ рд╕рдордп рддрдХ рдЬреБрдбрд╝реЗ рд╣реБрдП рдЧреНрд░рдВрдереЛрдВ рдХреЛ рдкреНрд░рд╛рдкреНрдд рдХрд░рдирд╛ рд╕рдорд╕реНрдпрд╛рдЧреНрд░рд╕реНрдд рд╣реИ, рд▓реЗрдХрд┐рди рдлрд┐рд░ рднреА, рднрд╛рд╖рд╛ рдореЙрдбрд▓ рднрд╛рд╖рд╛ рдХреЗ рдЧреБрдгреЛрдВ рдХреЛ рдкрдХрдбрд╝рдиреЗ рдореЗрдВ рд╕рдХреНрд╖рдо рд╣реИрдВ, рд╢рдмреНрджреЛрдВ рдХреЗ рдЙрдкрдпреЛрдЧ рдХреЗ рд╕рдВрджрд░реНрдн рдХреЛ рд╕рдордЭрддреЗ рд╣реИрдВ, рдЗрд╕рд▓рд┐рдП рдпрд╣ рднрд╛рд╖рд╛ рдореЙрдбрд▓ рд╣реИ (рдФрд░ рдирд╣реАрдВ, рдЙрджрд╛рд╣рд░рдг рдХреЗ рд▓рд┐рдП, рд╢рдмреНрджреЛрдВ рдХрд╛ рд╡реЗрдХреНрдЯрд░ рдкреНрд░рджрд░реНрд╢рди) рдЬреЛ рдХрд┐ рдкреНрд░реМрджреНрдпреЛрдЧрд┐рдХреА рдХрд╛ рдЖрдзрд╛рд░ рд╣реИред рднрд╛рд╖рд╛ рдХреЗ рдореЙрдбрд▓рд┐рдВрдЧ рдХреЗ рдХрд╛рд░реНрдп рдХреЗ рд▓рд┐рдП, ULMFit
AWD-LSTM рдЖрд░реНрдХрд┐рдЯреЗрдХреНрдЪрд░ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рддрд╛ рд╣реИ, рдЬрд┐рд╕рдореЗрдВ рдЬрд╣рд╛рдВ рднреА рд╕рдВрднрд╡ рд╣реЛ, рдбреНрд░реЙрдкрдЖрдЙрдЯ рдХрд╛ рд╕рдХреНрд░рд┐рдп рдЙрдкрдпреЛрдЧ рд╢рд╛рдорд┐рд▓ рд╣реИ рдФрд░ рд╕рдордЭ рдореЗрдВ рдЖрддрд╛ рд╣реИред рднрд╛рд╖рд╛ рдореЙрдбрд▓ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЗ рдкреНрд░рдХрд╛рд░ рдХреЛ рдХрднреА-рдХрднреА рдЕрд░реНрдз-рдкрд░реНрдпрд╡реЗрдХреНрд╖рдгреАрдп рд╢рд┐рдХреНрд╖рд╛ рдХрд╣рд╛ рдЬрд╛рддрд╛ рд╣реИ, рдХреНрдпреЛрдВрдХрд┐ рдпрд╣рд╛рдВ рд▓реЗрдмрд▓ рдЕрдЧрд▓реЗ рд╢рдмреНрдж рд╣реИ рдФрд░ рдХреБрдЫ рднреА рдЖрдкрдХреЗ рд╣рд╛рдереЛрдВ рд╕реЗ рдЪрд┐рд╣реНрдирд┐рдд рдХрд░рдиреЗ рдХреА рдЖрд╡рд╢реНрдпрдХрддрд╛ рдирд╣реАрдВ рд╣реИред
рдкреВрд░реНрд╡-рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рднрд╛рд╖рд╛ рдореЙрдбрд▓ рдХреЗ рд░реВрдк рдореЗрдВ, рд╣рдо рд▓рдЧрднрдЧ рдПрдХрдорд╛рддреНрд░
рдЙрдкрд▓рдмреНрдз рд╕рд╛рд░реНрд╡рдЬрдирд┐рдХ рд░реВрдк рд╕реЗ рдЙрдкрдпреЛрдЧ рдХрд░реЗрдВрдЧреЗред
рдЖрдЗрдП рд╢реБрд░реВ рд╕реЗ рд╣реА рд╕реАрдЦрдиреЗ рдХреЗ рдПрд▓реНрдЧреЛрд░рд┐рдереНрдо рдХреЗ рдорд╛рдзреНрдпрдо рд╕реЗ рдЪрд▓рддреЗ рд╣реИрдВред
рд╣рдо рдкреБрд╕реНрддрдХрд╛рд▓рдпреЛрдВ рдХреЛ рд▓реЛрдб рдХрд░рддреЗ рд╣реИрдВ (рд╣рдо рдХрд┐рд╕реА рднреА рдЕрд╕рдВрдЧрддрддрд╛ рдХреЗ рдорд╛рдорд▓реЗ рдореЗрдВ Fast.ai рдХреЗ рд╕рдВрд╕реНрдХрд░рдг рдХреА рдЬрд╛рдВрдЪ рдХрд░рддреЗ рд╣реИрдВ):
%load_ext autoreload %autoreload 2 import pandas as pd import numpy as np import re import statistics import fastai print('fast.ai version is:', fastai.__version__) from fastai import * from fastai.text import * from sklearn.model_selection import train_test_split path = ''
Out: fast.ai version is: 1.0.58
рд╣рдо рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЗ рд▓рд┐рдП рдбреЗрдЯрд╛ рддреИрдпрд╛рд░ рдХрд░рддреЗ рд╣реИрдВ
рд╕рд╛рджреГрд╢реНрдп рд╕реЗ, рд╣рдо
рдпреВрд▓рд┐рдпрд╛ рд░реБрдмрддрд╕реЛрд╡рд╛ рджреНрд╡рд╛рд░рд╛ рд▓рдШреБ рдЧреНрд░рдВрде RuTweetCorp рдХреЗ
рд╢рд░реАрд░ рдкрд░ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХрд╛ рдЖрдпреЛрдЬрди рдХрд░реЗрдВрдЧреЗ, рдЬреЛ рдЯреНрд╡рд┐рдЯрд░ рд╕реЗ рд░реВрд╕реА-рднрд╛рд╖рд╛ рдХреЗ рд╕рдВрджреЗрд╢реЛрдВ рдХреЗ рдЖрдзрд╛рд░ рдкрд░ рдмрдирд╛рдпрд╛ рдЧрдпрд╛ рд╣реИред рд╢рд░реАрд░ рдореЗрдВ 114,991 рд╕рдХрд╛рд░рд╛рддреНрдордХ рдЯреНрд╡реАрдЯ рдФрд░ CSV рдкреНрд░рд╛рд░реВрдк рдореЗрдВ 111,923 рдирдХрд╛рд░рд╛рддреНрдордХ рдЯреНрд╡реАрдЯ рд╣реИрдВред рдЗрд╕рдХреЗ рдЕрд▓рд╛рд╡рд╛, SQL рдкреНрд░рд╛рд░реВрдк рдореЗрдВ 17 639 674 рд░рд┐рдХреЙрд░реНрдб рдХреА рдорд╛рддреНрд░рд╛ рдХреЗ рд╕рд╛рде рдЕрд╕рдВрдмрджреНрдз рдЯреНрд╡реАрдЯреНрд╕ рдХрд╛ рдПрдХ рдбреЗрдЯрд╛рдмреЗрд╕ рд╣реИред рд╣рдорд╛рд░реЗ рд╡рд░реНрдЧреАрдХрд░рдг рдХрд╛ рдХрд╛рд░реНрдп рдпрд╣ рдирд┐рд░реНрдзрд╛рд░рд┐рдд рдХрд░рдирд╛ рд╣реЛрдЧрд╛ рдХрд┐ рдЯреНрд╡реАрдЯ рд╕рдХрд╛рд░рд╛рддреНрдордХ рд╣реИ рдпрд╛ рдирдХрд╛рд░рд╛рддреНрдордХред
рдЪреВрдБрдХрд┐
17 рдорд┐рд▓рд┐рдпрди рдЯреНрд╡реАрдЯреНрд╕ рдкрд░ рднрд╛рд╖рд╛ рдореЙрдбрд▓ рдХреЛ рдлрд┐рд░ рд╕реЗ рд▓рд┐рдЦрдирд╛ рдерд╛ рдФрд░ рдЯреНрд░рд╛рдВрд╕рдлрд░ рд▓рд░реНрдирд┐рдВрдЧ рдХреЛ рджрд┐рдЦрд╛рдиреЗ рдХреЗ рд▓рд┐рдП рдХрд╛рд░реНрдп рдХрд░рдирд╛ рдерд╛, рдЗрд╕рд▓рд┐рдП рд╣рдо рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдбреЗрдЯрд╛рд╕реЗрдЯ рд╕реЗ рдкрд╛рда рдХреЗ рдПрдХ рдЯреБрдХрдбрд╝реЗ рдкрд░ рднрд╛рд╖рд╛ рдореЙрдбрд▓ рдХреЛ рдлрд┐рд░ рд╕реЗ рд▓рд┐рдЦрдирд╛ рдЪрд╛рд╣рддреЗ рд╣реИрдВ, рдкреВрд░реА рддрд░рд╣ рд╕реЗ рдЕрд╕рдВрдмрджреНрдз рдЯреНрд╡реАрдЯ рдХреЗ рдЖрдзрд╛рд░ рдХреЛ рдЕрдирджреЗрдЦрд╛ рдХрд░ рд░рд╣реЗ рд╣реИрдВред рд╕рдВрднрд╡рддрдГ, рднрд╛рд╖рд╛ рдореЙрдбрд▓ рдХреЛ "рддреЗрдЬ" рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП рдЗрд╕ рдЖрдзрд╛рд░ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ, рдЖрдк рд╕рдордЧреНрд░ рдкрд░рд┐рдгрд╛рдо рдореЗрдВ рд╕реБрдзрд╛рд░ рдХрд░ рд╕рдХрддреЗ рд╣реИрдВред
рд╣рдо рдкреНрд░рд╛рд░рдВрднрд┐рдХ рд╢рдмреНрдж рдкреНрд░рд╕рдВрд╕реНрдХрд░рдг рдХреЗ рд╕рд╛рде рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдФрд░ рдкрд░реАрдХреНрд╖рдг рдХреЗ рд▓рд┐рдП рдбреЗрдЯрд╛рд╕реЗрдЯ рдмрдирд╛рддреЗ рд╣реИрдВред рд╣рдо
рдореВрд▓ рд▓реЗрдЦ рд╕реЗ рдХреЛрдб рд▓реЗрддреЗ рд╣реИрдВ:
def preprocess_text(text): text = text.lower().replace("", "") text = re.sub('((www\.[^\s]+)|(https?://[^\s]+))', 'URL', text) text = re.sub('@[^\s]+', 'USER', text) text = re.sub('[^a-zA-Z--1-9]+', ' ', text) text = re.sub(' +', ' ', text) return text.strip() data = [preprocess_text(t) for t in raw_data]
df_train=pd.DataFrame(columns=['Text', 'Label']) df_test=pd.DataFrame(columns=['Text', 'Label']) df_train['Text'], df_test['Text'], df_train['Label'], df_test['Label'] = train_test_split(data, labels, test_size=0.2, random_state=1)
df_val=pd.DataFrame(columns=['Text', 'Label']) df_train, df_val = train_test_split(df_train, test_size=0.2, random_state=1)
рд╣рдо рджреЗрдЦрддреЗ рд╣реИрдВ рдХрд┐ рдХреНрдпрд╛ рд╣реБрдЖ:
df_train.groupby('Label').count()

df_val.groupby('Label').count()

df_test.groupby('Label').count()

рдПрдХ рднрд╛рд╖рд╛ рдореЙрдбрд▓ рд╕реАрдЦрдирд╛
рдбреЗрдЯрд╛ рд▓реЛрдб рд╣реЛ рд░рд╣рд╛ рд╣реИ:
tokenizer=Tokenizer(lang='xx') data_lm = TextLMDataBunch.from_df(path, tokenizer=tokenizer, bs=16, train_df=df_train, valid_df=df_val, text_cols=0)
рд╣рдо рд╕рд╛рдордЧреНрд░реА рдХреЛ рджреЗрдЦрддреЗ рд╣реИрдВ:
data_lm.show_batch()

рд╣рдо
рдкреВрд░реНрд╡ рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдореЙрдбрд▓ рдФрд░ рдПрдХ рд╢рдмреНрджрдХреЛрд╢ рдХреЗ рд╕рдВрдЧреНрд░рд╣реАрдд рднрд╛рд░ рдХреЛ рд▓рд┐рдВрдХ рдкреНрд░рджрд╛рди рдХрд░рддреЗ рд╣реИрдВ:
weights_pretrained = 'ULMFit/lm_5_ep_lr2-3_5_stlr' itos_pretrained = 'ULMFit/itos' pretained_data = (weights_pretrained, itos_pretrained)
рд╣рдо рд╢рд┐рдХреНрд╖рд╛рд░реНрдереА рдмрдирд╛рддреЗ рд╣реИрдВ, рд▓реЗрдХрд┐рди рдЙрд╕рд╕реЗ рдкрд╣рд▓реЗ - рдЙрдкрд╡рд╛рд╕ рдХреЗ рд▓рд┐рдП рдПрдХ рдмреИрд╕рд╛рдЦреАред рдкреВрд░реНрд╡-рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдореЙрдбрд▓ рдкреБрд╕реНрддрдХрд╛рд▓рдп рдХреЗ рдПрдХ рдкреБрд░рд╛рдиреЗ рд╕рдВрд╕реНрдХрд░рдг рдкрд░ рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдХрд┐рдпрд╛ рдЧрдпрд╛ рдерд╛, рдЗрд╕рд▓рд┐рдП рдЖрдкрдХреЛ рддрдВрддреНрд░рд┐рдХрд╛ рдиреЗрдЯрд╡рд░реНрдХ рдХреА рдЫрд┐рдкреА рд╣реБрдИ рдкрд░рдд рдореЗрдВ рдиреЛрдбреНрд╕ рдХреА рд╕рдВрдЦреНрдпрд╛ рдХреЛ рд╕рдорд╛рдпреЛрдЬрд┐рдд рдХрд░рдиреЗ рдХреА рдЖрд╡рд╢реНрдпрдХрддрд╛ рд╣реИред
config = awd_lstm_lm_config.copy() config['n_hid'] = 1150 learn_lm = language_model_learner(data_lm, AWD_LSTM, config=config, pretrained_fnames=pretained_data, drop_mult=0.3) learn_lm.freeze()
рд╣рдо рдЗрд╖реНрдЯрддрдо рд╕реАрдЦрдиреЗ рдХреА рджрд░ рдХреА рддрд▓рд╛рд╢ рдХрд░ рд░рд╣реЗ рд╣реИрдВ:
learn_lm.lr_find() learn_lm.recorder.plot()

рд╣рдо рддреАрд╕рд░реЗ рдпреБрдЧ рдХреЗ рдореЙрдбрд▓ рдХреЛ рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдХрд░рддреЗ рд╣реИрдВ (рдореЙрдбрд▓ рдореЗрдВ, рдХреЗрд╡рд▓ рдкрд░рддреЛрдВ рдХрд╛ рдЕрдВрддрд┐рдо рд╕рдореВрд╣ рдЕрдкрд░рд┐рд╡рд░реНрддрдиреАрдп рд╣реИ)ред
learn_lm.fit_one_cycle(3, 1e-2, moms=(0.8, 0.7))

рдореЙрдбрд▓ рдХреЛ рдкрд░рд┐рднрд╛рд╖рд┐рдд рдХрд░рдирд╛, рдХрдо рд╕реАрдЦрдиреЗ рдХреА рджрд░ рдХреЗ рд╕рд╛рде 5 рдФрд░ рдпреБрдЧреЛрдВ рдХреЛ рдкрдврд╝рд╛рдирд╛:
learn_lm.unfreeze() learn_lm.fit_one_cycle(5, 1e-3, moms=(0.8, 0.7))

learn_lm.save('lm_ft')
рд╣рдо рдПрдХ рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдореЙрдбрд▓ рдкрд░ рдкрд╛рда рдЙрддреНрдкрдиреНрди рдХрд░рдиреЗ рдХрд╛ рдкреНрд░рдпрд╛рд╕ рдХрд░рддреЗ рд╣реИрдВред
learn_lm.predict(" ", n_words=5)
Out: ' '
learn_lm.predict(", ", n_words=4)
Out: ', '
рд╣рдо рджреЗрдЦрддреЗ рд╣реИрдВ - рдХреБрдЫ рдРрд╕рд╛ рдЬреЛ рдореЙрдбрд▓ рдХрд░рддрд╛ рд╣реИред рд▓реЗрдХрд┐рди рд╣рдорд╛рд░рд╛ рдореБрдЦреНрдп рдХрд╛рд░реНрдп рд╡рд░реНрдЧреАрдХрд░рдг рд╣реИ, рдФрд░ рдЗрд╕рдХреЗ рд╕рдорд╛рдзрд╛рди рдХреЗ рд▓рд┐рдП рд╣рдо рдореЙрдбрд▓ рд╕реЗ рдПрдХ рдПрдирдХреЛрдбрд░ рд▓реЗрдВрдЧреЗред
learn_lm.save_encoder('ft_enc')
рд╣рдо рдХреНрд▓рд╛рд╕рд┐рдлрд╛рдпрд░ рдЯреНрд░реЗрди рдХрд░рддреЗ рд╣реИрдВ
рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЗ рд▓рд┐рдП рдбреЗрдЯрд╛ рдбрд╛рдЙрдирд▓реЛрдб рдХрд░реЗрдВ
data_clas = TextClasDataBunch.from_df(path, vocab=data_lm.train_ds.vocab, bs=32, train_df=df_train, valid_df=df_val, text_cols=0, label_cols=1, tokenizer=tokenizer)
рдЖрдЗрдП рдбреЗрдЯрд╛ рдХреЛ рджреЗрдЦреЗрдВ, рд╣рдо рджреЗрдЦрддреЗ рд╣реИрдВ рдХрд┐ рд▓реЗрдмрд▓ рд╕рдлрд▓рддрд╛рдкреВрд░реНрд╡рдХ рдЧрд┐рдиреЗ рдЧрдП рдереЗ (0 рдХрд╛ рдорддрд▓рдм рдирдХрд╛рд░рд╛рддреНрдордХ рд╣реИ, рдФрд░ 1 рдХрд╛ рдорддрд▓рдм рд╕рдХрд╛рд░рд╛рддреНрдордХ рдЯрд┐рдкреНрдкрдгреА рд╣реИ):
data_clas.show_batch()

рдПрдХ рд╕рдорд╛рди рдмреИрд╕рд╛рдЦреА рдХреЗ рд╕рд╛рде рдПрдХ рд╢рд┐рдХреНрд╖рд╛рд░реНрдереА рдмрдирд╛рдПрдБ:
config = awd_lstm_clas_config.copy() config['n_hid'] = 1150 learn = text_classifier_learner(data_clas, AWD_LSTM, config=config, drop_mult=0.5)
рд╣рдо рдкрд┐рдЫрд▓реЗ рдЪрд░рдг рдореЗрдВ рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдПрдирдХреЛрдбрд░ рдХреЛ рд▓реЛрдб рдХрд░рддреЗ рд╣реИрдВ рдФрд░ рдореЙрдбрд▓ рдХреЛ рднрд╛рд░ рдХреЗ рдЕрдВрддрд┐рдо рд╕рдореВрд╣ рдХреЛ рдЫреЛрдбрд╝рдХрд░ рдлреНрд░реАрдЬ рдХрд░рддреЗ рд╣реИрдВ:
learn.load_encoder('ft_enc') learn.freeze()
рд╣рдо рдЗрд╖реНрдЯрддрдо рд╕реАрдЦрдиреЗ рдХреА рджрд░ рдХреА рддрд▓рд╛рд╢ рдХрд░ рд░рд╣реЗ рд╣реИрдВ:
learn.lr_find() learn.recorder.plot(skip_start=0)

рд╣рдо рдкрд░рддреЛрдВ рдХреЗ рдХреНрд░рдорд┐рдХ рд╡рд┐рдЧрд▓рди рдХреЗ рд╕рд╛рде рдореЙрдбрд▓ рдХреЛ рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдХрд░рддреЗ рд╣реИрдВред
learn.fit_one_cycle(2, 2e-2, moms=(0.8,0.7))

learn.freeze_to(-2) learn.fit_one_cycle(3, slice(1e-2/(2.6**4),1e-2), moms=(0.8,0.7))

learn.freeze_to(-3) learn.fit_one_cycle(2, slice(5e-3/(2.6**4),5e-3), moms=(0.8,0.7))

learn.unfreeze() learn.fit_one_cycle(2, slice(1e-3/(2.6**4),1e-3), moms=(0.8,0.7))

learn.save('tweet-0801')
рд╣рдо рджреЗрдЦрддреЗ рд╣реИрдВ рдХрд┐ рд╕рддреНрдпрд╛рдкрди рдХреЗ рдирдореВрдиреЗ рдкрд░ рдЙрдиреНрд╣реЛрдВрдиреЗ рд╕рдЯреАрдХрддрд╛ = 80.1% рд╣рд╛рд╕рд┐рд▓ рдХреАред
рд╣рдо рдЕрдкрдиреЗ рдкрд┐рдЫрд▓реЗ рд▓реЗрдЦ рдкрд░
рдЬрд╝реНрд▓реЛрдбреАрдмрд▓ рдЯрд┐рдкреНрдкрдгреА рдкрд░ рдореЙрдбрд▓ рдХрд╛
рдкрд░реАрдХреНрд╖рдг рдХрд░реЗрдВрдЧреЗ :
learn.predict(' тАФ ?')
Out: (Category 0, tensor(0), tensor([0.6283, 0.3717]))
рд╣рдо рджреЗрдЦрддреЗ рд╣реИрдВ рдХрд┐ рдореЙрдбрд▓ рдиреЗ рдЗрд╕ рдЯрд┐рдкреНрдкрдгреА рдХреЛ рдирдХрд╛рд░рд╛рддреНрдордХ :-) рдХреЗ рд▓рд┐рдП рдЬрд┐рдореНрдореЗрджрд╛рд░ рдард╣рд░рд╛рдпрд╛ рд╣реИ
рдкрд░реАрдХреНрд╖рдг рдирдореВрдиреЗ рдкрд░ рдореЙрдбрд▓ рдХреА рдЬрд╛рдБрдЪ рдХрд░рдирд╛
рдЗрд╕ рд╕реНрддрд░ рдкрд░ рдореБрдЦреНрдп рдХрд╛рд░реНрдп рд╕рд╛рдорд╛рдиреНрдпреАрдХрд░рдг рдХреНрд╖рдорддрд╛ рдХреЗ рд▓рд┐рдП рдореЙрдбрд▓ рдХрд╛ рдкрд░реАрдХреНрд╖рдг рдХрд░рдирд╛ рд╣реИред рдРрд╕рд╛ рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП, рд╣рдо DataFrame df_test рдореЗрдВ рд╕рдВрдЧреНрд░рд╣реАрдд рдбреЗрдЯрд╛рд╕реЗрдЯ рдкрд░ рдореЙрдбрд▓ рдХреЛ рдорд╛рдиреНрдп рдХрд░рддреЗ рд╣реИрдВ, рдЬреЛ рддрдм рддрдХ рднрд╛рд╖рд╛ рдореЙрдбрд▓ рдпрд╛ рдХреНрд▓рд╛рд╕рд┐рдлрд╛рдпрд░рд┐рдпрд░ рдХреЗ рд▓рд┐рдП рдЙрдкрд▓рдмреНрдз рдирд╣реАрдВ рдерд╛ред
data_test_clas = TextClasDataBunch.from_df(path, vocab=data_lm.train_ds.vocab, bs=32, train_df=df_train, valid_df=df_test, text_cols=0, label_cols=1, tokenizer=tokenizer)
config = awd_lstm_clas_config.copy() config['n_hid'] = 1150 learn_test = text_classifier_learner(data_test_clas, AWD_LSTM, config=config, drop_mult=0.5)
learn_test.load_encoder('ft_enc') learn_test.load('tweet-0801')
learn_test.validate()
Out: [0.4391682, tensor(0.7973)]
рд╣рдо рджреЗрдЦрддреЗ рд╣реИрдВ рдХрд┐ рдкрд░реАрдХреНрд╖рдг рдирдореВрдиреЗ рдкрд░ рд╕рдЯреАрдХрддрд╛ 79.7% рдереАред
рднреНрд░рдо рдореИрдЯреНрд░рд┐рдХреНрд╕ рдкрд░ рдПрдХ рдирдЬрд╝рд░ рдбрд╛рд▓реЗрдВ:
interp = ClassificationInterpretation.from_learner(learn) interp.plot_confusion_matrix()

рд╣рдо рд╕рдЯреАрдХ, рд░рд┐рдХреЙрд▓ рдФрд░ рдПрдл 1 рд╕реНрдХреЛрд░ рдорд╛рдкрджрдВрдбреЛрдВ рдХреА рдЧрдгрдирд╛ рдХрд░рддреЗ рд╣реИрдВред
neg_precision = interp.confusion_matrix()[0][0] / (interp.confusion_matrix()[0][0] + interp.confusion_matrix()[1][0]) neg_recall = interp.confusion_matrix()[0][0] / (interp.confusion_matrix()[0][0] + interp.confusion_matrix()[0][1]) pos_precision = interp.confusion_matrix()[1][1] / (interp.confusion_matrix()[1][1] + interp.confusion_matrix()[0][1]) pos_recall = interp.confusion_matrix()[1][1] / (interp.confusion_matrix()[1][1] + interp.confusion_matrix()[1][0]) neg_f1score = 2 * (neg_precision * neg_recall) / (neg_precision + neg_recall) pos_f1score = 2 * (pos_precision * pos_recall) / (pos_precision + pos_recall)
print(' F1-score') print(' Negative {0:1.5f} {1:1.5f} {2:1.5f}'.format(neg_precision, neg_recall, neg_f1score)) print(' Positive {0:1.5f} {1:1.5f} {2:1.5f}'.format(pos_precision, pos_recall, pos_f1score)) print(' Average {0:1.5f} {1:1.5f} {2:1.5f}'.format(statistics.mean([neg_precision, pos_precision]), statistics.mean([neg_recall, pos_recall]), statistics.mean([neg_f1score, pos_f1score])))
Out: F1-score Negative 0.79989 0.80451 0.80219 Positive 0.80142 0.79675 0.79908 Average 0.80066 0.80063 0.80064
рдкрд░реАрдХреНрд╖рдг рдирдореВрдирд╛ рдФрд╕рдд F1-рд╕реНрдХреЛрд░ = 0.80064 рдореЗрдВ рджрд┐рдЦрд╛рдпрд╛ рдЧрдпрд╛ рдкрд░рд┐рдгрд╛рдоред
рд╕рд╣реЗрдЬреЗ рдЧрдП рдореЙрдбрд▓ рд╡рдЬрди рдХреЛ
рдпрд╣рд╛рдВ рд▓реЗ рдЬрд╛рдпрд╛ рдЬрд╛ рд╕рдХрддрд╛
рд╣реИ ред