рд╣рдо Fast.ai рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ рдЧреНрд░рдВрдереЛрдВ рдХреА рдЯреЛрди рдХрд╛ рд╡рд┐рд╢реНрд▓реЗрд╖рдг рдХрд░рддреЗ рд╣реИрдВ

рд▓реЗрдЦ рд░реВрд╕реА рдореЗрдВ рдкрд╛рда рд╕рдВрджреЗрд╢реЛрдВ рдХреЗ рдЯрди рдХреА рд╡рд░реНрдЧреАрдХрд░рдг рдкрд░ рдЪрд░реНрдЪрд╛ рдХрд░реЗрдЧрд╛ (рдФрд░ рдЕрдирд┐рд╡рд╛рд░реНрдп рд░реВрдк рд╕реЗ рдЙрд╕реА рддрдХрдиреАрдХ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ рдЧреНрд░рдВрдереЛрдВ рдХреЗ рдХрд┐рд╕реА рднреА рд╡рд░реНрдЧреАрдХрд░рдг)ред рд╣рдо рдЗрд╕ рд▓реЗрдЦ рдХреЛ рдПрдХ рдЖрдзрд╛рд░ рдХреЗ рд░реВрдк рдореЗрдВ рд▓реЗрдВрдЧреЗ, рдЬрд┐рд╕рдореЗрдВ Word2vec рдореЙрдбрд▓ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ CNN рдЖрд░реНрдХрд┐рдЯреЗрдХреНрдЪрд░ рдкрд░ рдЖрдЬ рдХреА рд░рд╛рдд рдХреЗ рд╡рд░реНрдЧреАрдХрд░рдг рдкрд░ рд╡рд┐рдЪрд╛рд░ рдХрд┐рдпрд╛ рдЧрдпрд╛ рдерд╛ред рд╣рдорд╛рд░реЗ рдЙрджрд╛рд╣рд░рдг рдореЗрдВ, рд╣рдо ULMFit рдореЙрдбрд▓ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ рдПрдХ рд╣реА рдбреЗрдЯрд╛рд╕реЗрдЯ рдкрд░ рд╕рдХрд╛рд░рд╛рддреНрдордХ рдФрд░ рдирдХрд╛рд░рд╛рддреНрдордХ рдореЗрдВ рдЯреНрд╡реАрдЯреНрд╕ рдХреЛ рдЕрд▓рдЧ рдХрд░рдиреЗ рдХреА рдПрдХ рд╣реА рд╕рдорд╕реНрдпрд╛ рдХреЛ рд╣рд▓ рдХрд░реЗрдВрдЧреЗред рд▓реЗрдЦ рд╕реЗ рдкрд░рд┐рдгрд╛рдо (рдФрд╕рдд рдПрдл 1-рд╕реНрдХреЛрд░ = 0.78142) рдХреЛ рдЖрдзрд╛рд░ рд░реЗрдЦрд╛ рдХреЗ рд░реВрдк рдореЗрдВ рд╕реНрд╡реАрдХрд╛рд░ рдХрд┐рдпрд╛ рдЬрд╛рдПрдЧрд╛ред

рдкрд░рд┐рдЪрдп


ULMFIT рдореЙрдбрд▓ рдХреЛ 2018 рдореЗрдВ fast.ai Developers (рдЬреЗрд░реЗрдореА рд╣реЙрд╡рд░реНрдб, рд╕реЗрдмреЗрд╕реНрдЯрд┐рдпрди рд░реБрдбрд░) рджреНрд╡рд╛рд░рд╛ рдкреЗрд╢ рдХрд┐рдпрд╛ рдЧрдпрд╛ рдерд╛ред рджреГрд╖реНрдЯрд┐рдХреЛрдг рдХрд╛ рд╕рд╛рд░ рдПрдирдПрд▓рдкреА рдХрд╛рд░реНрдпреЛрдВ рдореЗрдВ рдЯреНрд░рд╛рдВрд╕рдлрд░ рд▓рд░реНрдирд┐рдВрдЧ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдирд╛ рд╣реИ рдЬрдм рдЖрдк рдкреВрд░реНрд╡-рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдореЙрдбрд▓ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рддреЗ рд╣реИрдВ, рдЕрдкрдиреЗ рдореЙрдбрд▓ рдХреЗ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЗ рд▓рд┐рдП рд╕рдордп рдХрдо рдХрд░рддреЗ рд╣реИрдВ рдФрд░ рд▓реЗрдмрд▓ рдХрд┐рдП рдЧрдП рдкрд░реАрдХреНрд╖рдг рдирдореВрдиреЗ рдХреЗ рдЖрдХрд╛рд░ рдХреЗ рд▓рд┐рдП рдЖрд╡рд╢реНрдпрдХрддрд╛рдУрдВ рдХреЛ рдХрдо рдХрд░рддреЗ рд╣реИрдВред

рд╣рдорд╛рд░реЗ рдорд╛рдорд▓реЗ рдореЗрдВ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдпреЛрдЬрдирд╛ рдЗрд╕ рдкреНрд░рдХрд╛рд░ рд╣реЛрдЧреА:



рднрд╛рд╖рд╛ рдореЙрдбрд▓ рдХрд╛ рдЕрд░реНрде рдЕрдиреБрдХреНрд░рдо рдореЗрдВ рдЕрдЧрд▓реЗ рд╢рдмреНрдж рдХреА рднрд╡рд┐рд╖реНрдпрд╡рд╛рдгреА рдХрд░рдиреЗ рдореЗрдВ рд╕рдХреНрд╖рдо рд╣реЛрдирд╛ рд╣реИред рдЗрд╕ рддрд░рд╣ рд╕реЗ рд▓рдВрдмреЗ рд╕рдордп рддрдХ рдЬреБрдбрд╝реЗ рд╣реБрдП рдЧреНрд░рдВрдереЛрдВ рдХреЛ рдкреНрд░рд╛рдкреНрдд рдХрд░рдирд╛ рд╕рдорд╕реНрдпрд╛рдЧреНрд░рд╕реНрдд рд╣реИ, рд▓реЗрдХрд┐рди рдлрд┐рд░ рднреА, рднрд╛рд╖рд╛ рдореЙрдбрд▓ рднрд╛рд╖рд╛ рдХреЗ рдЧреБрдгреЛрдВ рдХреЛ рдкрдХрдбрд╝рдиреЗ рдореЗрдВ рд╕рдХреНрд╖рдо рд╣реИрдВ, рд╢рдмреНрджреЛрдВ рдХреЗ рдЙрдкрдпреЛрдЧ рдХреЗ рд╕рдВрджрд░реНрдн рдХреЛ рд╕рдордЭрддреЗ рд╣реИрдВ, рдЗрд╕рд▓рд┐рдП рдпрд╣ рднрд╛рд╖рд╛ рдореЙрдбрд▓ рд╣реИ (рдФрд░ рдирд╣реАрдВ, рдЙрджрд╛рд╣рд░рдг рдХреЗ рд▓рд┐рдП, рд╢рдмреНрджреЛрдВ рдХрд╛ рд╡реЗрдХреНрдЯрд░ рдкреНрд░рджрд░реНрд╢рди) рдЬреЛ рдХрд┐ рдкреНрд░реМрджреНрдпреЛрдЧрд┐рдХреА рдХрд╛ рдЖрдзрд╛рд░ рд╣реИред рднрд╛рд╖рд╛ рдХреЗ рдореЙрдбрд▓рд┐рдВрдЧ рдХреЗ рдХрд╛рд░реНрдп рдХреЗ рд▓рд┐рдП, ULMFit AWD-LSTM рдЖрд░реНрдХрд┐рдЯреЗрдХреНрдЪрд░ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рддрд╛ рд╣реИ, рдЬрд┐рд╕рдореЗрдВ рдЬрд╣рд╛рдВ рднреА рд╕рдВрднрд╡ рд╣реЛ, рдбреНрд░реЙрдкрдЖрдЙрдЯ рдХрд╛ рд╕рдХреНрд░рд┐рдп рдЙрдкрдпреЛрдЧ рд╢рд╛рдорд┐рд▓ рд╣реИ рдФрд░ рд╕рдордЭ рдореЗрдВ рдЖрддрд╛ рд╣реИред рднрд╛рд╖рд╛ рдореЙрдбрд▓ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЗ рдкреНрд░рдХрд╛рд░ рдХреЛ рдХрднреА-рдХрднреА рдЕрд░реНрдз-рдкрд░реНрдпрд╡реЗрдХреНрд╖рдгреАрдп рд╢рд┐рдХреНрд╖рд╛ рдХрд╣рд╛ рдЬрд╛рддрд╛ рд╣реИ, рдХреНрдпреЛрдВрдХрд┐ рдпрд╣рд╛рдВ рд▓реЗрдмрд▓ рдЕрдЧрд▓реЗ рд╢рдмреНрдж рд╣реИ рдФрд░ рдХреБрдЫ рднреА рдЖрдкрдХреЗ рд╣рд╛рдереЛрдВ рд╕реЗ рдЪрд┐рд╣реНрдирд┐рдд рдХрд░рдиреЗ рдХреА рдЖрд╡рд╢реНрдпрдХрддрд╛ рдирд╣реАрдВ рд╣реИред

рдкреВрд░реНрд╡-рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рднрд╛рд╖рд╛ рдореЙрдбрд▓ рдХреЗ рд░реВрдк рдореЗрдВ, рд╣рдо рд▓рдЧрднрдЧ рдПрдХрдорд╛рддреНрд░ рдЙрдкрд▓рдмреНрдз рд╕рд╛рд░реНрд╡рдЬрдирд┐рдХ рд░реВрдк рд╕реЗ рдЙрдкрдпреЛрдЧ рдХрд░реЗрдВрдЧреЗред
рдЖрдЗрдП рд╢реБрд░реВ рд╕реЗ рд╣реА рд╕реАрдЦрдиреЗ рдХреЗ рдПрд▓реНрдЧреЛрд░рд┐рдереНрдо рдХреЗ рдорд╛рдзреНрдпрдо рд╕реЗ рдЪрд▓рддреЗ рд╣реИрдВред

рд╣рдо рдкреБрд╕реНрддрдХрд╛рд▓рдпреЛрдВ рдХреЛ рд▓реЛрдб рдХрд░рддреЗ рд╣реИрдВ (рд╣рдо рдХрд┐рд╕реА рднреА рдЕрд╕рдВрдЧрддрддрд╛ рдХреЗ рдорд╛рдорд▓реЗ рдореЗрдВ Fast.ai рдХреЗ рд╕рдВрд╕реНрдХрд░рдг рдХреА рдЬрд╛рдВрдЪ рдХрд░рддреЗ рд╣реИрдВ):

%load_ext autoreload %autoreload 2 import pandas as pd import numpy as np import re import statistics import fastai print('fast.ai version is:', fastai.__version__) from fastai import * from fastai.text import * from sklearn.model_selection import train_test_split path = '' 

 Out: fast.ai version is: 1.0.58 

рд╣рдо рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЗ рд▓рд┐рдП рдбреЗрдЯрд╛ рддреИрдпрд╛рд░ рдХрд░рддреЗ рд╣реИрдВ


рд╕рд╛рджреГрд╢реНрдп рд╕реЗ, рд╣рдо рдпреВрд▓рд┐рдпрд╛ рд░реБрдмрддрд╕реЛрд╡рд╛ рджреНрд╡рд╛рд░рд╛ рд▓рдШреБ рдЧреНрд░рдВрде RuTweetCorp рдХреЗ рд╢рд░реАрд░ рдкрд░ рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХрд╛ рдЖрдпреЛрдЬрди рдХрд░реЗрдВрдЧреЗ, рдЬреЛ рдЯреНрд╡рд┐рдЯрд░ рд╕реЗ рд░реВрд╕реА-рднрд╛рд╖рд╛ рдХреЗ рд╕рдВрджреЗрд╢реЛрдВ рдХреЗ рдЖрдзрд╛рд░ рдкрд░ рдмрдирд╛рдпрд╛ рдЧрдпрд╛ рд╣реИред рд╢рд░реАрд░ рдореЗрдВ 114,991 рд╕рдХрд╛рд░рд╛рддреНрдордХ рдЯреНрд╡реАрдЯ рдФрд░ CSV рдкреНрд░рд╛рд░реВрдк рдореЗрдВ 111,923 рдирдХрд╛рд░рд╛рддреНрдордХ рдЯреНрд╡реАрдЯ рд╣реИрдВред рдЗрд╕рдХреЗ рдЕрд▓рд╛рд╡рд╛, SQL рдкреНрд░рд╛рд░реВрдк рдореЗрдВ 17 639 674 рд░рд┐рдХреЙрд░реНрдб рдХреА рдорд╛рддреНрд░рд╛ рдХреЗ рд╕рд╛рде рдЕрд╕рдВрдмрджреНрдз рдЯреНрд╡реАрдЯреНрд╕ рдХрд╛ рдПрдХ рдбреЗрдЯрд╛рдмреЗрд╕ рд╣реИред рд╣рдорд╛рд░реЗ рд╡рд░реНрдЧреАрдХрд░рдг рдХрд╛ рдХрд╛рд░реНрдп рдпрд╣ рдирд┐рд░реНрдзрд╛рд░рд┐рдд рдХрд░рдирд╛ рд╣реЛрдЧрд╛ рдХрд┐ рдЯреНрд╡реАрдЯ рд╕рдХрд╛рд░рд╛рддреНрдордХ рд╣реИ рдпрд╛ рдирдХрд╛рд░рд╛рддреНрдордХред

рдЪреВрдБрдХрд┐ 17 рдорд┐рд▓рд┐рдпрди рдЯреНрд╡реАрдЯреНрд╕ рдкрд░ рднрд╛рд╖рд╛ рдореЙрдбрд▓ рдХреЛ рдлрд┐рд░ рд╕реЗ рд▓рд┐рдЦрдирд╛ рдерд╛ рдФрд░ рдЯреНрд░рд╛рдВрд╕рдлрд░ рд▓рд░реНрдирд┐рдВрдЧ рдХреЛ рджрд┐рдЦрд╛рдиреЗ рдХреЗ рд▓рд┐рдП рдХрд╛рд░реНрдп рдХрд░рдирд╛ рдерд╛, рдЗрд╕рд▓рд┐рдП рд╣рдо рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдбреЗрдЯрд╛рд╕реЗрдЯ рд╕реЗ рдкрд╛рда рдХреЗ рдПрдХ рдЯреБрдХрдбрд╝реЗ рдкрд░ рднрд╛рд╖рд╛ рдореЙрдбрд▓ рдХреЛ рдлрд┐рд░ рд╕реЗ рд▓рд┐рдЦрдирд╛ рдЪрд╛рд╣рддреЗ рд╣реИрдВ, рдкреВрд░реА рддрд░рд╣ рд╕реЗ рдЕрд╕рдВрдмрджреНрдз рдЯреНрд╡реАрдЯ рдХреЗ рдЖрдзрд╛рд░ рдХреЛ рдЕрдирджреЗрдЦрд╛ рдХрд░ рд░рд╣реЗ рд╣реИрдВред рд╕рдВрднрд╡рддрдГ, рднрд╛рд╖рд╛ рдореЙрдбрд▓ рдХреЛ "рддреЗрдЬ" рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП рдЗрд╕ рдЖрдзрд╛рд░ рдХрд╛ рдЙрдкрдпреЛрдЧ рдХрд░рдХреЗ, рдЖрдк рд╕рдордЧреНрд░ рдкрд░рд┐рдгрд╛рдо рдореЗрдВ рд╕реБрдзрд╛рд░ рдХрд░ рд╕рдХрддреЗ рд╣реИрдВред

рд╣рдо рдкреНрд░рд╛рд░рдВрднрд┐рдХ рд╢рдмреНрдж рдкреНрд░рд╕рдВрд╕реНрдХрд░рдг рдХреЗ рд╕рд╛рде рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдФрд░ рдкрд░реАрдХреНрд╖рдг рдХреЗ рд▓рд┐рдП рдбреЗрдЯрд╛рд╕реЗрдЯ рдмрдирд╛рддреЗ рд╣реИрдВред рд╣рдо рдореВрд▓ рд▓реЗрдЦ рд╕реЗ рдХреЛрдб рд▓реЗрддреЗ рд╣реИрдВ:

 #   n = ['id', 'date', 'name', 'text', 'typr', 'rep', 'rtw', 'faw', 'stcount', 'foll', 'frien', 'listcount'] data_positive = pd.read_csv('data/positive.csv', sep=';', error_bad_lines=False, names=n, usecols=['text']) data_negative = pd.read_csv('data/negative.csv', sep=';', error_bad_lines=False, names=n, usecols=['text']) #    sample_size = min(data_positive.shape[0], data_negative.shape[0]) raw_data = np.concatenate((data_positive['text'].values[:sample_size], data_negative['text'].values[:sample_size]), axis=0) labels = [1] * sample_size + [0] * sample_size 

 def preprocess_text(text): text = text.lower().replace("", "") text = re.sub('((www\.[^\s]+)|(https?://[^\s]+))', 'URL', text) text = re.sub('@[^\s]+', 'USER', text) text = re.sub('[^a-zA-Z--1-9]+', ' ', text) text = re.sub(' +', ' ', text) return text.strip() data = [preprocess_text(t) for t in raw_data] 

 df_train=pd.DataFrame(columns=['Text', 'Label']) df_test=pd.DataFrame(columns=['Text', 'Label']) df_train['Text'], df_test['Text'], df_train['Label'], df_test['Label'] = train_test_split(data, labels, test_size=0.2, random_state=1) 

 df_val=pd.DataFrame(columns=['Text', 'Label']) df_train, df_val = train_test_split(df_train, test_size=0.2, random_state=1) 

рд╣рдо рджреЗрдЦрддреЗ рд╣реИрдВ рдХрд┐ рдХреНрдпрд╛ рд╣реБрдЖ:

 df_train.groupby('Label').count() 



 df_val.groupby('Label').count() 



 df_test.groupby('Label').count() 



рдПрдХ рднрд╛рд╖рд╛ рдореЙрдбрд▓ рд╕реАрдЦрдирд╛


рдбреЗрдЯрд╛ рд▓реЛрдб рд╣реЛ рд░рд╣рд╛ рд╣реИ:

 tokenizer=Tokenizer(lang='xx') data_lm = TextLMDataBunch.from_df(path, tokenizer=tokenizer, bs=16, train_df=df_train, valid_df=df_val, text_cols=0) 

рд╣рдо рд╕рд╛рдордЧреНрд░реА рдХреЛ рджреЗрдЦрддреЗ рд╣реИрдВ:

 data_lm.show_batch() 



рд╣рдо рдкреВрд░реНрд╡ рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдореЙрдбрд▓ рдФрд░ рдПрдХ рд╢рдмреНрджрдХреЛрд╢ рдХреЗ рд╕рдВрдЧреНрд░рд╣реАрдд рднрд╛рд░ рдХреЛ рд▓рд┐рдВрдХ рдкреНрд░рджрд╛рди рдХрд░рддреЗ рд╣реИрдВ:

 weights_pretrained = 'ULMFit/lm_5_ep_lr2-3_5_stlr' itos_pretrained = 'ULMFit/itos' pretained_data = (weights_pretrained, itos_pretrained) 

рд╣рдо рд╢рд┐рдХреНрд╖рд╛рд░реНрдереА рдмрдирд╛рддреЗ рд╣реИрдВ, рд▓реЗрдХрд┐рди рдЙрд╕рд╕реЗ рдкрд╣рд▓реЗ - рдЙрдкрд╡рд╛рд╕ рдХреЗ рд▓рд┐рдП рдПрдХ рдмреИрд╕рд╛рдЦреАред рдкреВрд░реНрд╡-рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдореЙрдбрд▓ рдкреБрд╕реНрддрдХрд╛рд▓рдп рдХреЗ рдПрдХ рдкреБрд░рд╛рдиреЗ рд╕рдВрд╕реНрдХрд░рдг рдкрд░ рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдХрд┐рдпрд╛ рдЧрдпрд╛ рдерд╛, рдЗрд╕рд▓рд┐рдП рдЖрдкрдХреЛ рддрдВрддреНрд░рд┐рдХрд╛ рдиреЗрдЯрд╡рд░реНрдХ рдХреА рдЫрд┐рдкреА рд╣реБрдИ рдкрд░рдд рдореЗрдВ рдиреЛрдбреНрд╕ рдХреА рд╕рдВрдЦреНрдпрд╛ рдХреЛ рд╕рдорд╛рдпреЛрдЬрд┐рдд рдХрд░рдиреЗ рдХреА рдЖрд╡рд╢реНрдпрдХрддрд╛ рд╣реИред

 config = awd_lstm_lm_config.copy() config['n_hid'] = 1150 learn_lm = language_model_learner(data_lm, AWD_LSTM, config=config, pretrained_fnames=pretained_data, drop_mult=0.3) learn_lm.freeze() 

рд╣рдо рдЗрд╖реНрдЯрддрдо рд╕реАрдЦрдиреЗ рдХреА рджрд░ рдХреА рддрд▓рд╛рд╢ рдХрд░ рд░рд╣реЗ рд╣реИрдВ:

 learn_lm.lr_find() learn_lm.recorder.plot() 


рд╣рдо рддреАрд╕рд░реЗ рдпреБрдЧ рдХреЗ рдореЙрдбрд▓ рдХреЛ рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдХрд░рддреЗ рд╣реИрдВ (рдореЙрдбрд▓ рдореЗрдВ, рдХреЗрд╡рд▓ рдкрд░рддреЛрдВ рдХрд╛ рдЕрдВрддрд┐рдо рд╕рдореВрд╣ рдЕрдкрд░рд┐рд╡рд░реНрддрдиреАрдп рд╣реИ)ред

 learn_lm.fit_one_cycle(3, 1e-2, moms=(0.8, 0.7)) 


рдореЙрдбрд▓ рдХреЛ рдкрд░рд┐рднрд╛рд╖рд┐рдд рдХрд░рдирд╛, рдХрдо рд╕реАрдЦрдиреЗ рдХреА рджрд░ рдХреЗ рд╕рд╛рде 5 рдФрд░ рдпреБрдЧреЛрдВ рдХреЛ рдкрдврд╝рд╛рдирд╛:

 learn_lm.unfreeze() learn_lm.fit_one_cycle(5, 1e-3, moms=(0.8, 0.7)) 



 learn_lm.save('lm_ft') 

рд╣рдо рдПрдХ рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдореЙрдбрд▓ рдкрд░ рдкрд╛рда рдЙрддреНрдкрдиреНрди рдХрд░рдиреЗ рдХрд╛ рдкреНрд░рдпрд╛рд╕ рдХрд░рддреЗ рд╣реИрдВред

 learn_lm.predict("  ", n_words=5) 

 Out: '       ' 

 learn_lm.predict(",  ", n_words=4) 

 Out: ',      ' 

рд╣рдо рджреЗрдЦрддреЗ рд╣реИрдВ - рдХреБрдЫ рдРрд╕рд╛ рдЬреЛ рдореЙрдбрд▓ рдХрд░рддрд╛ рд╣реИред рд▓реЗрдХрд┐рди рд╣рдорд╛рд░рд╛ рдореБрдЦреНрдп рдХрд╛рд░реНрдп рд╡рд░реНрдЧреАрдХрд░рдг рд╣реИ, рдФрд░ рдЗрд╕рдХреЗ рд╕рдорд╛рдзрд╛рди рдХреЗ рд▓рд┐рдП рд╣рдо рдореЙрдбрд▓ рд╕реЗ рдПрдХ рдПрдирдХреЛрдбрд░ рд▓реЗрдВрдЧреЗред

 learn_lm.save_encoder('ft_enc') 

рд╣рдо рдХреНрд▓рд╛рд╕рд┐рдлрд╛рдпрд░ рдЯреНрд░реЗрди рдХрд░рддреЗ рд╣реИрдВ


рдкреНрд░рд╢рд┐рдХреНрд╖рдг рдХреЗ рд▓рд┐рдП рдбреЗрдЯрд╛ рдбрд╛рдЙрдирд▓реЛрдб рдХрд░реЗрдВ

 data_clas = TextClasDataBunch.from_df(path, vocab=data_lm.train_ds.vocab, bs=32, train_df=df_train, valid_df=df_val, text_cols=0, label_cols=1, tokenizer=tokenizer) 

рдЖрдЗрдП рдбреЗрдЯрд╛ рдХреЛ рджреЗрдЦреЗрдВ, рд╣рдо рджреЗрдЦрддреЗ рд╣реИрдВ рдХрд┐ рд▓реЗрдмрд▓ рд╕рдлрд▓рддрд╛рдкреВрд░реНрд╡рдХ рдЧрд┐рдиреЗ рдЧрдП рдереЗ (0 рдХрд╛ рдорддрд▓рдм рдирдХрд╛рд░рд╛рддреНрдордХ рд╣реИ, рдФрд░ 1 рдХрд╛ рдорддрд▓рдм рд╕рдХрд╛рд░рд╛рддреНрдордХ рдЯрд┐рдкреНрдкрдгреА рд╣реИ):

 data_clas.show_batch() 



рдПрдХ рд╕рдорд╛рди рдмреИрд╕рд╛рдЦреА рдХреЗ рд╕рд╛рде рдПрдХ рд╢рд┐рдХреНрд╖рд╛рд░реНрдереА рдмрдирд╛рдПрдБ:

 config = awd_lstm_clas_config.copy() config['n_hid'] = 1150 learn = text_classifier_learner(data_clas, AWD_LSTM, config=config, drop_mult=0.5) 

рд╣рдо рдкрд┐рдЫрд▓реЗ рдЪрд░рдг рдореЗрдВ рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдПрдирдХреЛрдбрд░ рдХреЛ рд▓реЛрдб рдХрд░рддреЗ рд╣реИрдВ рдФрд░ рдореЙрдбрд▓ рдХреЛ рднрд╛рд░ рдХреЗ рдЕрдВрддрд┐рдо рд╕рдореВрд╣ рдХреЛ рдЫреЛрдбрд╝рдХрд░ рдлреНрд░реАрдЬ рдХрд░рддреЗ рд╣реИрдВ:
 learn.load_encoder('ft_enc') learn.freeze() 

рд╣рдо рдЗрд╖реНрдЯрддрдо рд╕реАрдЦрдиреЗ рдХреА рджрд░ рдХреА рддрд▓рд╛рд╢ рдХрд░ рд░рд╣реЗ рд╣реИрдВ:

 learn.lr_find() learn.recorder.plot(skip_start=0) 


рд╣рдо рдкрд░рддреЛрдВ рдХреЗ рдХреНрд░рдорд┐рдХ рд╡рд┐рдЧрд▓рди рдХреЗ рд╕рд╛рде рдореЙрдбрд▓ рдХреЛ рдкреНрд░рд╢рд┐рдХреНрд╖рд┐рдд рдХрд░рддреЗ рд╣реИрдВред

 learn.fit_one_cycle(2, 2e-2, moms=(0.8,0.7)) 



 learn.freeze_to(-2) learn.fit_one_cycle(3, slice(1e-2/(2.6**4),1e-2), moms=(0.8,0.7)) 



 learn.freeze_to(-3) learn.fit_one_cycle(2, slice(5e-3/(2.6**4),5e-3), moms=(0.8,0.7)) 



 learn.unfreeze() learn.fit_one_cycle(2, slice(1e-3/(2.6**4),1e-3), moms=(0.8,0.7)) 



 learn.save('tweet-0801') 

рд╣рдо рджреЗрдЦрддреЗ рд╣реИрдВ рдХрд┐ рд╕рддреНрдпрд╛рдкрди рдХреЗ рдирдореВрдиреЗ рдкрд░ рдЙрдиреНрд╣реЛрдВрдиреЗ рд╕рдЯреАрдХрддрд╛ = 80.1% рд╣рд╛рд╕рд┐рд▓ рдХреАред

рд╣рдо рдЕрдкрдиреЗ рдкрд┐рдЫрд▓реЗ рд▓реЗрдЦ рдкрд░ рдЬрд╝реНрд▓реЛрдбреАрдмрд▓ рдЯрд┐рдкреНрдкрдгреА рдкрд░ рдореЙрдбрд▓ рдХрд╛ рдкрд░реАрдХреНрд╖рдг рдХрд░реЗрдВрдЧреЗ :

 learn.predict('        тАФ ?') 

 Out: (Category 0, tensor(0), tensor([0.6283, 0.3717])) 

рд╣рдо рджреЗрдЦрддреЗ рд╣реИрдВ рдХрд┐ рдореЙрдбрд▓ рдиреЗ рдЗрд╕ рдЯрд┐рдкреНрдкрдгреА рдХреЛ рдирдХрд╛рд░рд╛рддреНрдордХ :-) рдХреЗ рд▓рд┐рдП рдЬрд┐рдореНрдореЗрджрд╛рд░ рдард╣рд░рд╛рдпрд╛ рд╣реИ

рдкрд░реАрдХреНрд╖рдг рдирдореВрдиреЗ рдкрд░ рдореЙрдбрд▓ рдХреА рдЬрд╛рдБрдЪ рдХрд░рдирд╛


рдЗрд╕ рд╕реНрддрд░ рдкрд░ рдореБрдЦреНрдп рдХрд╛рд░реНрдп рд╕рд╛рдорд╛рдиреНрдпреАрдХрд░рдг рдХреНрд╖рдорддрд╛ рдХреЗ рд▓рд┐рдП рдореЙрдбрд▓ рдХрд╛ рдкрд░реАрдХреНрд╖рдг рдХрд░рдирд╛ рд╣реИред рдРрд╕рд╛ рдХрд░рдиреЗ рдХреЗ рд▓рд┐рдП, рд╣рдо DataFrame df_test рдореЗрдВ рд╕рдВрдЧреНрд░рд╣реАрдд рдбреЗрдЯрд╛рд╕реЗрдЯ рдкрд░ рдореЙрдбрд▓ рдХреЛ рдорд╛рдиреНрдп рдХрд░рддреЗ рд╣реИрдВ, рдЬреЛ рддрдм рддрдХ рднрд╛рд╖рд╛ рдореЙрдбрд▓ рдпрд╛ рдХреНрд▓рд╛рд╕рд┐рдлрд╛рдпрд░рд┐рдпрд░ рдХреЗ рд▓рд┐рдП рдЙрдкрд▓рдмреНрдз рдирд╣реАрдВ рдерд╛ред

 data_test_clas = TextClasDataBunch.from_df(path, vocab=data_lm.train_ds.vocab, bs=32, train_df=df_train, valid_df=df_test, text_cols=0, label_cols=1, tokenizer=tokenizer) 

 config = awd_lstm_clas_config.copy() config['n_hid'] = 1150 learn_test = text_classifier_learner(data_test_clas, AWD_LSTM, config=config, drop_mult=0.5) 

 learn_test.load_encoder('ft_enc') learn_test.load('tweet-0801') 

 learn_test.validate() 

 Out: [0.4391682, tensor(0.7973)] 

рд╣рдо рджреЗрдЦрддреЗ рд╣реИрдВ рдХрд┐ рдкрд░реАрдХреНрд╖рдг рдирдореВрдиреЗ рдкрд░ рд╕рдЯреАрдХрддрд╛ 79.7% рдереАред

рднреНрд░рдо рдореИрдЯреНрд░рд┐рдХреНрд╕ рдкрд░ рдПрдХ рдирдЬрд╝рд░ рдбрд╛рд▓реЗрдВ:

 interp = ClassificationInterpretation.from_learner(learn) interp.plot_confusion_matrix() 



рд╣рдо рд╕рдЯреАрдХ, рд░рд┐рдХреЙрд▓ рдФрд░ рдПрдл 1 рд╕реНрдХреЛрд░ рдорд╛рдкрджрдВрдбреЛрдВ рдХреА рдЧрдгрдирд╛ рдХрд░рддреЗ рд╣реИрдВред

 neg_precision = interp.confusion_matrix()[0][0] / (interp.confusion_matrix()[0][0] + interp.confusion_matrix()[1][0]) neg_recall = interp.confusion_matrix()[0][0] / (interp.confusion_matrix()[0][0] + interp.confusion_matrix()[0][1]) pos_precision = interp.confusion_matrix()[1][1] / (interp.confusion_matrix()[1][1] + interp.confusion_matrix()[0][1]) pos_recall = interp.confusion_matrix()[1][1] / (interp.confusion_matrix()[1][1] + interp.confusion_matrix()[1][0]) neg_f1score = 2 * (neg_precision * neg_recall) / (neg_precision + neg_recall) pos_f1score = 2 * (pos_precision * pos_recall) / (pos_precision + pos_recall) 

 print('    F1-score') print(' Negative {0:1.5f} {1:1.5f} {2:1.5f}'.format(neg_precision, neg_recall, neg_f1score)) print(' Positive {0:1.5f} {1:1.5f} {2:1.5f}'.format(pos_precision, pos_recall, pos_f1score)) print(' Average {0:1.5f} {1:1.5f} {2:1.5f}'.format(statistics.mean([neg_precision, pos_precision]), statistics.mean([neg_recall, pos_recall]), statistics.mean([neg_f1score, pos_f1score]))) 

 Out:     F1-score Negative 0.79989 0.80451 0.80219 Positive 0.80142 0.79675 0.79908 Average 0.80066 0.80063 0.80064 

рдкрд░реАрдХреНрд╖рдг рдирдореВрдирд╛ рдФрд╕рдд F1-рд╕реНрдХреЛрд░ = 0.80064 рдореЗрдВ рджрд┐рдЦрд╛рдпрд╛ рдЧрдпрд╛ рдкрд░рд┐рдгрд╛рдоред

рд╕рд╣реЗрдЬреЗ рдЧрдП рдореЙрдбрд▓ рд╡рдЬрди рдХреЛ рдпрд╣рд╛рдВ рд▓реЗ рдЬрд╛рдпрд╛ рдЬрд╛ рд╕рдХрддрд╛ рд╣реИ ред

Source: https://habr.com/ru/post/hi472988/


All Articles