How I taught a neural network to implement the position evaluation function in Russian AI Cup CodeBall 2018

When you can evaluate the game situation at any moment and can also simulate the game world, a natural way to write a bot is to choose the actions that increase this evaluation in the near future. That is what I aimed for in one of my solutions.
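The idea above boils down to a greedy one-step lookahead: simulate each candidate action a little way ahead and keep the one whose resulting position scores best. The following Python sketch only illustrates that loop. The one-dimensional `Ball`, `simulate`, `evaluate` and the list of kicks are hypothetical stand-ins, not the game's physics or the author's bot code.

```python
from dataclasses import dataclass


@dataclass
class Ball:
    z: float   # position along the long axis of the field (hypothetical 1D model)
    vz: float  # velocity along that axis


def simulate(ball: Ball, kick: float, ticks: int = 30, dt: float = 1 / 60) -> Ball:
    """Hypothetical physics: apply a velocity change, then move in a straight line."""
    vz = ball.vz + kick
    return Ball(ball.z + vz * ticks * dt, vz)


def evaluate(ball: Ball) -> float:
    """Hypothetical evaluation: the further towards +Z (the opponent's goal), the better."""
    return ball.z


def best_action(ball: Ball, actions):
    # Greedy choice: the action whose simulated outcome has the highest evaluation
    return max(actions, key=lambda a: evaluate(simulate(ball, a)))


print(best_action(Ball(z=0.0, vz=-5.0), actions=[-10.0, 0.0, 10.0]))  # 10.0
```

In a real bot the evaluation would be the function described next, and the simulation would be the game's own world model.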

The position evaluation function returns a real number: the lower it is, the worse the position. As input I pass it only the ball's position and velocity vectors. Initially the function was implemented with a fairly simple formula and a few if statements. Even so, it was good enough to generate a set of local runner logs for training a neural network later. So I ran 300 games locally (18,000 ticks each), which produced about 12 GB of logs in total, plus 145 logs of top games downloaded from the server (5.7 GB).

Next, I had to extract training and test samples from these logs. I did it as follows: starting from each goal, I looked into the "past" up to 300 ticks (5 seconds) and, every 5 ticks, took the ball's position and velocity plus the reference score as one sample.
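To make the sampling scheme concrete, here is a small sketch of which ticks end up in the dataset for a single goal. The helper name `sample_ticks` is hypothetical; the constants are the ones from the text. Note that Python's `range` excludes its upper bound, so the offsets are 5, 10, ..., 295, i.e. 59 samples per goal, and the sketch also skips ticks before the start of the log.

```python
TICKS_PER_SECOND = 60  # CodeBall tick rate, so 300 ticks = 5 s
TICKS_BACKWARD = 300   # how far back from a goal to look
TICKS_STEP = 5         # take one sample every 5 ticks


def sample_ticks(goal_tick: int):
    """Tick indices sampled before a goal, nearest first, never before tick 0."""
    return [goal_tick - t
            for t in range(TICKS_STEP, TICKS_BACKWARD, TICKS_STEP)
            if goal_tick - t >= 0]


ticks = sample_ticks(1000)
print(len(ticks), ticks[0], ticks[-1])      # 59 995 705
print(TICKS_BACKWARD / TICKS_PER_SECOND)    # 5.0
```

The real extraction code additionally drops ticks that precede the previous goal.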

An important point: the reference score (the output) is computed here by the formula

$$O = \frac{S}{e^{T/60}}$$

where S = -1 if the ball flies into "my" goal and S = 1 otherwise, and T is the time in ticks remaining until the goal (so T/60 is that time in seconds).
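A quick numeric check of the formula shows how fast the target decays: it loses a factor of e for every 60 ticks (one second) between the sample and the goal. The function name `reference_score` is just an illustrative name.

```python
import math


def reference_score(s: int, t_ticks: float) -> float:
    """O = S / exp(T / 60): S is +1 or -1, T is the number of ticks left before the goal."""
    return s / math.exp(t_ticks / 60)


for t in (0, 60, 300):
    print(t, round(reference_score(1, t), 4))
# 0 1.0
# 60 0.3679
# 300 0.0067
```

So a sample taken 5 seconds before a goal contributes a target of only about 0.0067 in absolute value, while positions right before a goal are close to ±1.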

Another point, less important but still significant: the playing field is symmetric, so the reference score must also be antisymmetric when seen from the opponent's point of view. That is, if a position is evaluated from "my" point of view as X, the same position must be evaluated from the opponent's point of view as -X. This means that if you "fold in half" the whole input space of the neural network along some parameter, the network will learn better, roughly speaking "twice" as well, and, most importantly, it will give answers that are guaranteed to be symmetric (which, at the very least, is simply elegant). I "folded" the ball's velocity along the Z axis. Put simply, if the ball flies away from "my" goal, I look at the position from my own point of view; otherwise, from the opponent's. As a result, for the neural network the ball always flies in the positive Z direction. The same can be done for the mirror symmetry along the X axis, except that in this case we keep looking from the same team's point of view, but as if in a mirror lying in the plane with normal (1, 0, 0).
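The folding can be written as a small standalone function, mirroring what the log-processing code below does. `fold` is a hypothetical name. Looking from the opponent's side is a 180° rotation about the Y axis, so it negates both X and Z and flips the sign of the score; the X mirror keeps the same team's viewpoint and therefore leaves the sign alone. The demo checks that a state and the same state seen by the opponent produce identical network inputs and opposite signs.

```python
from typing import List, Tuple

Vec3 = Tuple[float, float, float]


def fold(velocity: Vec3, position: Vec3) -> Tuple[List[float], int]:
    """Map a ball state into the network's canonical half-space.

    Returns the 6 network inputs and `sign`, the factor to multiply the
    network output by to get the score from "my" point of view.
    """
    vx, vy, vz = velocity
    px, py, pz = position
    # Opponent's viewpoint = 180-degree rotation about Y (negates X and Z),
    # chosen so that the ball always flies towards +Z.
    sign = -1 if vz < 0 else 1
    # Mirror in the plane with normal (1, 0, 0) so the X velocity is non-negative.
    sign_x = -1 if vx * sign < 0 else 1
    inputs = [vx * sign * sign_x, vy, vz * sign,
              px * sign * sign_x, py, pz * sign]
    return inputs, sign


state = ((3.0, 1.0, -8.0), (5.0, 2.0, 10.0))
rotated = ((-3.0, 1.0, 8.0), (-5.0, 2.0, -10.0))  # the same state seen by the opponent
print(fold(*state))    # ([3.0, 1.0, 8.0, 5.0, 2.0, -10.0], -1)
print(fold(*rotated))  # ([3.0, 1.0, 8.0, 5.0, 2.0, -10.0], 1)
```

Because both views feed the network the same inputs and differ only in `sign`, the final rating is antisymmetric by construction, whatever the network has learned.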

So here is the Python code that prepares the training and test samples from the logs:

```python
import glob
import json
import random

import numpy as np

xtrain = []
ytrain = []
xtest = []
ytest = []

f1 = r"F:\Home\Projects\MailRuAI\Codeball2018\LocalRunner\logs_archive\logs_01/*.txt"
f2 = r"F:\Home\Projects\MailRuAI\Codeball2018\LocalRunner\logs_archive\logs_02/*.txt"
f3 = r"F:\Home\Projects\MailRuAI\Codeball2018\LocalRunner\logs_archive\logs_03/*.txt"
f7 = r"F:\Home\Projects\MailRuAI\Codeball2018\downloaded_games/*.txt"

for file in (glob.glob(f1) + glob.glob(f2) + glob.glob(f3) + glob.glob(f7)):
    with open(file) as f:
        content = f.readlines()
    print(len(content))
    print(file)
    sumofscores = 0
    lastscore0 = 0
    lastscore1 = 0
    ticksbackward = 300    # look back at most 300 ticks (5 s)
    ticksbackwardinc = 5   # take a sample every 5 ticks
    for x in range(0, len(content)):
        data = json.loads(content[x])
        # A goal was scored on this tick: the total score increased
        if "scores" in data and sum(data["scores"]) > sumofscores:
            sumofscores = sum(data["scores"])
            value = 0
            if data["scores"][0] > lastscore0:
                lastscore0 = data["scores"][0]
                value = 1    # "my" side scored
            if data["scores"][1] > lastscore1:
                lastscore1 = data["scores"][1]
                value = -1   # the opponent scored
            for y in range(ticksbackwardinc, ticksbackward, ticksbackwardinc):
                if x - y < 0:
                    break    # do not wrap around to the end of the log
                dataY = json.loads(content[x - y])
                # Only use ticks between the previous goal and this one
                if "scores" in dataY and sum(dataY["scores"]) == sumofscores - 1:
                    # Fold along Z (180-degree rotation about Y): ball always flies towards +Z
                    sign = 1
                    if dataY['ball']['velocity']['z'] < 0:
                        sign = -1
                    # Mirror along X so that the folded X velocity is non-negative
                    signX = 1
                    if dataY['ball']['velocity']['x'] * sign < 0:
                        signX = -1
                    inputs = np.zeros(6)
                    inputs[0] = dataY['ball']['velocity']['x'] * sign * signX
                    inputs[1] = dataY['ball']['velocity']['y']
                    inputs[2] = dataY['ball']['velocity']['z'] * sign
                    inputs[3] = dataY['ball']['position']['x'] * sign * signX
                    inputs[4] = dataY['ball']['position']['y']
                    inputs[5] = dataY['ball']['position']['z'] * sign
                    outputs = np.zeros(2)
                    outputs[0] = value * sign  # S, seen from the folded viewpoint
                    outputs[1] = y             # T, ticks left before the goal
                    # Random split: roughly 80% training, 20% test
                    if random.random() > 0.2:
                        xtrain.append(inputs)
                        ytrain.append(outputs)
                    else:
                        xtest.append(inputs)
                        ytest.append(outputs)
                else:
                    print("exceeded")

print(len(xtrain))
print(len(xtest))
np.save("F:/Home/Projects/MailRuAI/Codeball2018/nnet/xtrain_BR.npy", np.asarray(xtrain))
np.save("F:/Home/Projects/MailRuAI/Codeball2018/nnet/ytrain_BR.npy", np.asarray(ytrain))
np.save("F:/Home/Projects/MailRuAI/Codeball2018/nnet/xtest_BR.npy", np.asarray(xtest))
np.save("F:/Home/Projects/MailRuAI/Codeball2018/nnet/ytest_BR.npy", np.asarray(ytest))
```

The most attentive readers may have noticed that `outputs` contains two values, which is not at all what I described above. Don't worry: this is just a quirk, and the conversion into a single target happens right before training itself:

```python
import datetime

import numpy as np
from keras.layers import Concatenate, Dense, Input
from keras.models import Model

np.set_printoptions(edgeitems=50)

xtrain = np.load("F:/Home/Projects/MailRuAI/Codeball2018/nnet/xtrain_BR.npy")
ytrain = np.load("F:/Home/Projects/MailRuAI/Codeball2018/nnet/ytrain_BR.npy")
xtest = np.load("F:/Home/Projects/MailRuAI/Codeball2018/nnet/xtest_BR.npy")
ytest = np.load("F:/Home/Projects/MailRuAI/Codeball2018/nnet/ytest_BR.npy")

# Collapse the two stored outputs (S, T) into the target O = S * exp(-T / 60)
ytrain = np.exp(-(ytrain[:, 1]) / 60) * ytrain[:, 0]
ytest = np.exp(-(ytest[:, 1]) / 60) * ytest[:, 0]

# Each hidden layer: parallel relu, linear and sigmoid Dense layers, concatenated
inp = Input(shape=(xtrain.shape[1],))
d1 = Dense(6, activation='relu')(inp)
d2 = Dense(6, activation='linear')(inp)
d3 = Dense(6, activation='sigmoid')(inp)
added = Concatenate()([d1, d2, d3])

d21 = Dense(3, activation='relu')(added)
d22 = Dense(3, activation='linear')(added)
d23 = Dense(3, activation='sigmoid')(added)
added2 = Concatenate()([d21, d22, d23])

d31 = Dense(3, activation='relu')(added2)
d32 = Dense(3, activation='linear')(added2)
d33 = Dense(3, activation='sigmoid')(added2)
added3 = Concatenate()([d31, d32, d33])

out = Dense(1)(added3)

model = Model(inputs=inp, outputs=out)
model.compile(optimizer='adam', loss='mse', metrics=['mae'])
# model.load_weights("F:/Home/Projects/MailRuAI/Codeball2018/nnet/WEXP_B36F.dat")

for x in range(0, 10):
    # Reference losses before this round of training
    lostTR, maeTR = model.evaluate(xtrain, ytrain, verbose=0)
    print("Train mae: " + repr(lostTR) + ", " + repr(maeTR))
    lostTS, maeTS = model.evaluate(xtest, ytest, verbose=0)
    print("Test mae: " + repr(lostTS) + ", " + repr(maeTS))
    # Train one epoch at a time until the test loss beats the reference
    while True:
        model.fit(xtrain, ytrain, epochs=1, batch_size=1, verbose=2)
        print("Aim: " + repr(lostTS))
        lostTR2, maeTR2 = model.evaluate(xtrain, ytrain, verbose=0)
        print("Train mae: " + repr(lostTR2) + ", " + repr(maeTR2))
        lostTS2, maeTS2 = model.evaluate(xtest, ytest, verbose=0)
        print("Test mae: " + repr(lostTS2) + ", " + repr(maeTS2))
        print("Improve number: " + repr(x))
        print(datetime.datetime.now())
        if lostTS > lostTS2:
            print("improved")
            model.save_weights("F:/Home/Projects/MailRuAI/Codeball2018/nnet/WEXP_B36F.dat")
            break
```

Why exactly 3 hidden layers, and why exactly this configuration? I don't know myself. But intuition and days of experiments led precisely there.

And finally, the last question: how do you use a neural network trained in Python from C# without ready-made classes? Write the class yourself! With such a simple network configuration, and given that we only need to implement the "predict" function (that is, just a forward pass from input to output), it is quite simple. Here it is:

```csharp
using System.Collections.Generic;

public enum Activation { relu, linear, sigmoid };

public class layer
{
    public int Count = 0;
    public List<List<double>> weights = new List<List<double>>(); // weights[neuron][input]
    public List<double> Ps = new List<double>();                  // biases
    public List<Activation> funcs = new List<Activation>();
    public List<double> Values = new List<double>();              // outputs of the last Calculate

    public void Add(Activation aact)
    {
        Count++;
        weights.Add(new List<double>());
        Ps.Add(0);
        funcs.Add(aact);
        Values.Add(0);
    }

    public void Add(Activation aact, int acnt)
    {
        for (int i = 0; i < acnt; i++)
            Add(aact);
    }

    public void Calculate(List<double> ainps)
    {
        for (int i = 0; i < Count; i++)
        {
            // Weighted sum of the inputs plus bias
            Values[i] = Ps[i];
            for (int j = 0; j < ainps.Count; j++)
                Values[i] += weights[i][j] * ainps[j];
            // Per-neuron activation
            switch (funcs[i])
            {
                case Activation.linear:
                    break;
                case Activation.relu:
                    Values[i] = System.Math.Max(0, Values[i]);
                    break;
                case Activation.sigmoid:
                    Values[i] = (double)(1.0 / (1.0 + System.Math.Exp(-Values[i])));
                    break;
            }
        }
    }
}

public class nnet
{
    public int inputCount = 0;
    public List<layer> layers = new List<layer>();
    public layer outputLayer = null;

    public nnet(int ainputcount, int aoutputcount)
    {
        inputCount = ainputcount;
        outputLayer = new layer();
        outputLayer.Add(Activation.linear, aoutputcount);
    }

    // Forward pass: each hidden layer in turn, then the output layer
    public List<double> predict(List<double> ainput)
    {
        for (int i = 0; i < layers.Count + 1; i++)
        {
            List<double> inps = ainput;
            if (i > 0)
                inps = layers[i - 1].Values;
            layer lr = outputLayer;
            if (i < layers.Count)
                lr = layers[i];
            lr.Calculate(inps);
        }
        return outputLayer.Values;
    }
}
```
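When porting a network by hand like this, it helps to have a reference forward pass in Python to compare against. The sketch below is a hypothetical NumPy reimplementation of the same computation, using the same weight layout as `layer.weights` (one row per neuron, one column per input). Keras stores a `Dense` kernel as (inputs, units), so each row here corresponds to a column of the Keras kernel. The function names and the demo weights are made up for illustration.

```python
import numpy as np


def mixed_layer(x, weights, biases, activations):
    """One C#-style layer: weights has shape (units, inputs), one activation per neuron."""
    z = weights @ x + biases
    out = np.empty_like(z)
    for i, act in enumerate(activations):
        if act == "relu":
            out[i] = max(0.0, z[i])
        elif act == "sigmoid":
            out[i] = 1.0 / (1.0 + np.exp(-z[i]))
        else:  # "linear"
            out[i] = z[i]
    return out


def predict(x, layers):
    """Forward pass through a list of (weights, biases, activations) tuples."""
    for w, b, acts in layers:
        x = mixed_layer(x, w, b, acts)
    return x


# Made-up weights: 2 inputs -> [relu, linear, sigmoid] -> 1 linear output
layers = [
    (np.array([[1.0, -1.0], [1.0, 1.0], [0.0, 0.0]]), np.array([0.0, 0.0, 0.0]),
     ["relu", "linear", "sigmoid"]),
    (np.array([[1.0, 1.0, 2.0]]), np.array([0.5]), ["linear"]),
]
print(predict(np.array([1.0, 2.0]), layers))  # [4.5]
```

Feeding the same input to this function, to the Keras model and to the C# class should give the same number up to floating-point rounding.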

All that remains is to load the weights of the trained network (by the way, the weights quoted here are the ones that actually worked in my final version):

```csharp
public class trained_nnet : nnet
{
    // Copy a bias vector and a weight matrix (rows = neurons, columns = inputs) into a layer
    void FillLayer(layer al, double[] atp, double[,] atw)
    {
        al.Ps.Clear();
        al.Ps.AddRange(atp);
        al.weights.Clear();
        for (int i = 0; i < atw.GetLength(0); i++)
        {
            al.weights.Add(new List<double>());
            for (int j = 0; j < atw.GetLength(1); j++)
            {
                al.weights[i].Add(atw[i, j]);
            }
        }
    }

    public trained_nnet() : base(6, 1)
    {
        layer lr1 = new layer();
        lr1.Add(Activation.relu, 6);
        lr1.Add(Activation.linear, 6);
        lr1.Add(Activation.sigmoid, 6);
        base.layers.Add(lr1);
        layer lr2 = new layer();
        lr2.Add(Activation.relu, 3);
        lr2.Add(Activation.linear, 3);
        lr2.Add(Activation.sigmoid, 3);
        base.layers.Add(lr2);
        layer lr3 = new layer();
        lr3.Add(Activation.relu, 3);
        lr3.Add(Activation.linear, 3);
        lr3.Add(Activation.sigmoid, 3);
        base.layers.Add(lr3);

        double[] t = {
            3.6843767166137695, -9.454026222229004, -5.089229106903076, -2.850287437438965, -6.96286153793335, -9.751116752624512,
            10.384811401367188, -4.214056968688965, 1.2072025537490845, 1.4019242525100708, -0.13174889981746674, -13.1264066696167,
            -4.265004634857178, 1.8926845788955688, -0.0813497006893158, -1.4616785049438477, -5.361510753631592, -1.1896661520004272
        };
        double[,] t2 = {
            { 0.1477939784526825, 0.03613739833235741, -0.09796690940856934, 1.942456841468811, -0.3508949875831604, -0.5551134347915649 },
            { -0.25495094060897827, 0.049018844962120056, -0.15976546704769135, -1.881699562072754, -1.3928385972976685, 0.017490295693278313 },
            { 0.314727246761322, -0.7985705733299255, -0.16902890801429749, 0.7290273308753967, -3.3613057136535645, -0.501738965511322 },
            { -0.14706645905971527, 0.013889106921851635, -8.41325855255127, 0.08269797265529633, -0.8194255232810974, 0.054869525134563446 },
            { -0.11769858002662659, 0.024719441309571266, -32.9736213684082, -0.06565750390291214, -0.38925793766975403, -0.30816638469696045 },
            { -0.09536012262105942, -0.4411015212535858, -0.3092011511325836, 0.061532989144325256, -1.3718899488449097, -0.9904148578643799 },
            { 0.03862301632761955, -0.2239271104335785, -0.3054073452949524, 0.013336590491235256, -0.0404842384159565, -0.09027290344238281 },
            { -0.317527711391449, -0.14433158934116364, 0.06079907342791557, -0.4572157561779022, 0.2782846987247467, 0.17747753858566284 },
            { 0.01980031281709671, 0.015361669473350048, -0.03606397658586502, 0.013219496235251427, -0.03483833745121956, -0.01729537360370159 },
            { -0.003958317916840315, 0.09587077051401138, -0.08213665336370468, -0.027169639244675636, 0.032037656754255295, -0.030492693185806274 },
            { -0.04885690286755562, -0.06349656730890274, 0.013905149884521961, 0.018028201535344124, 0.012719585560262203, 0.002531017642468214 },
            { 0.016520477831363678, -0.00018591046682558954, -0.003657651599496603, 0.06888063997030258, -0.2127065807580948, 0.6427022218704224 },
            { -0.5308891534805298, 0.13539844751358032, 0.03864796832203865, 1.5582681894302368, -1.929693341255188, -3.2511842250823975 },
            { 0.032178860157728195, 1.1472656726837158, -2.020042896270752, -0.05141841620206833, -0.4635908901691437, 0.2636871039867401 },
            { 0.01480827759951353, 0.33971744775772095, -0.15343432128429413, 0.03558071702718735, 3.364596366882324, -0.7852638959884644 },
            { 0.0028303645085543394, 1.2297841310501099, -0.4412313997745514, 0.3644706606864929, 2.2155861854553223, -0.43303439021110535 },
            { -0.3666411340236664, 0.0464097335934639, 5.143652439117432, -2.2230076789855957, 0.3511424660682678, 1.0514445304870605 },
            { 0.014482858590781689, -0.4740144610404968, -1.6240901947021484, 1.7327706813812256, -1.5116417407989502, -1.6811648607254028 }
        };
        double[] t3 = {
            -3.09689998626709, -1.2031112909317017, -7.121585369110107, 2.0653932094573975, -2.8601508140563965,
            -1.6219528913497925, 0.16301754117012024, -6.890131950378418, 3.8225107192993164
        };
        double[,] t4 = {
            { -0.6246452927589417, -0.3575346767902374, 0.6897052526473999, -2.2513232231140137, -0.23217444121837616, 0.17847181856632233, -0.3863859176635742, -0.01201619766652584, 0.050539981573820114, 0.028343766927719116, 0.0034856200218200684, 0.5547005534172058, -0.4277774691581726, -1.0249099731445312, -8.995088577270508, -3.4937169551849365, 0.7673622369766235, -1.6504380702972412 },
            { -1.0006977319717407, -0.8660659790039062, -0.0415676049888134, -0.5476861000061035, -0.7828258872032166, -0.05350146442651749, 0.005586389917880297, -0.052493464201688766, 0.07955628633499146, -0.08084911853075027, 0.09794406592845917, -0.031214063987135887, -0.7785998582839966, -0.27977627515792847, -0.4096711277961731, -0.24633635580539703, -1.5932326316833496, -0.5430923104286194 },
            { -0.2330777496099472, -0.07477551698684692, -1.0634428262710571, -1.772096872329712, -1.4657013416290283, 0.6256936192512512, -0.1179097518324852, 0.07645376771688461, 0.008837736211717129, 0.030952733010053635, -0.013960030861198902, 1.0339184999465942, 0.20350944995880127, -0.047291483730077744, -4.043337345123291, -0.7629795670509338, -5.41167688369751, -3.7755305767059326 },
            { 0.00979659240692854, 0.11435728520154953, -0.4749748706817627, 1.5166815519332886, -5.3047380447387695, 0.9597445130348206, 0.08123911172151566, 0.039479970932006836, -0.01649349369108677, -0.04941410943865776, 0.020120851695537567, -0.16329358518123627, 0.36106961965560913, 0.5348165035247803, 0.11825983971357346, 0.2075480818748474, -1.8661850690841675, 1.4093444347381592 },
            { -0.35534173250198364, 0.3471201956272125, -0.2657061517238617, -2.4178225994110107, -3.890836238861084, 0.5999298691749573, -0.10068143904209137, 0.530009388923645, 0.023632165044546127, -0.006245455238968134, 0.031124670058488846, 0.016797777265310287, 1.720144510269165, -0.3200121223926544, 0.17827671766281128, -1.0847045183181763, 0.7679504156112671, 1.1521148681640625 },
            { 0.047243088483810425, -0.07313758134841919, -0.13496115803718567, -1.0498348474502563, -2.083388328552246, 0.3018227815628052, 0.019016921520233154, 0.00780009850859642, -0.02416112646460533, -0.012299800291657448, 0.019720694050192833, 0.019809948280453682, -1.637327790260315, 0.09307140856981277, 2.963168144226074, 0.515803337097168, 0.02399904653429985, -3.9851980209350586 },
            { -0.6250298023223877, -0.4796958863735199, 0.4311320185661316, -1.4590528011322021, -4.861763000488281, -1.1894060373306274, 0.31154727935791016, -0.028901753947138786, 0.07241783291101456, 0.0573900043964386, -0.16387903690338135, -0.7621306777000427, 2.864539623260498, 1.126343011856079, -0.729159414768219, 15.2516450881958, -0.5845442414283752, -0.2593745291233063 },
            { -0.4520488679409027, -0.37348034977912903, -0.22873088717460632, 2.816544532775879, 0.635391891002655, 1.7192658185958862, -0.042334891855716705, -0.012391769327223301, -0.00944773480296135, -0.047271229326725006, 0.045244403183460236, 1.1044175624847412, -2.682516098022461, -1.797003984451294, -5.227936744689941, 0.3994572162628174, -3.361297130584717, -0.16535422205924988 },
            { 1.3437395095825195, 0.05596136301755905, -0.6534030437469482, -3.2173333168029785, -3.256056785583496, 3.164973020553589, -0.6149216294288635, 0.3425371050834656, -0.13111716508865356, -0.42127469182014465, -0.0668950155377388, 0.19484268128871918, 2.005012273788452, -3.41219425201416, -0.3146309554576874, -2.1181774139404297, 2.2965285778045654, 5.287317276000977 }
        };
        double[] t5 = {
            -1.173705816268921, -1.8888208866119385, -2.566594123840332, 0.1278465986251831, 0.05948356166481972,
            -0.021375492215156555, -1.554726243019104, -2.2256762981414795, 1.3142614364624023
        };
        double[,] t6 = {
            { -0.023421021178364754, 0.17735084891319275, -0.1922600418329239, -0.11634820699691772, 0.05003879591822624, 0.07409390062093735, -0.131203755736351, -0.11743484437465668, -1.1311017274856567 },
            { -0.6256148219108582, -0.08678799867630005, 0.08910120278596878, -0.06354714930057526, 0.05225379019975662, 0.028936704620718956, -2.069547176361084, 0.16652414202690125, 0.4840211570262909 },
            { -0.9266191720962524, 0.1542767435312271, -1.511458396911621, -2.2593629360198975, 0.32768234610557556, 0.728438138961792, 1.4113644361495972, -2.9423279762268066, -1.1225157976150513 },
            { -0.31864309310913086, -0.06739992648363113, 1.8643943071365356, 0.12609687447547913, 0.003282073885202408, -0.08565603941679001, 0.22951357066631317, -3.9096572399139404, -0.5148558020591736 },
            { 0.0030701414216309786, 0.22653144598007202, -0.1772366166114807, 0.01472154725342989, 0.006688127294182777, 0.029435427859425545, -0.049562305212020874, -0.01126908790320158, -0.09357477724552155 },
            { -0.003160204039886594, 0.004133348818868399, 0.003914407920092344, 0.013578329235315323, 0.0036796496715396643, 0.028364477679133415, 0.025828130543231964, -0.030584659427404404, -0.0449080727994442 },
            { -0.15649960935115814, 0.7045242786407471, 4.971825122833252, 0.26150253415107727, 0.25615766644477844, -0.007457265630364418, 0.4002840220928192, -4.386100769042969, -0.14405106008052826 },
            { -1.283564805984497, -1.0451316833496094, -9.010445594787598, -0.23629669845104218, 0.8792487978935242, 0.12951965630054474, 2.7414908409118652, -10.04093074798584, 0.08805646747350693 },
            { 0.5142691731452942, 0.27933982014656067, 17.242839813232422, -0.14753387868404388, 0.35601550340652466, -0.03304799646139145, -0.3745580017566681, 3.6696081161499023, 0.18306805193424225 }
        };
        double[] t7 = { 0.057645831257104874 };
        double[,] t8 = {
            { 0.02502649463713169, 0.030625218525528908, -0.04921620339155197, -0.06382419914007187, -0.0018273837631568313, -0.002946096006780863, -0.3073849678039551, -0.0770358145236969, 0.44145819544792175 }
        };

        FillLayer(lr1, t, t2);
        FillLayer(lr2, t3, t4);
        FillLayer(lr3, t5, t6);
        FillLayer(outputLayer, t7, t8);
    }
}
```
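The text does not show how the weights were exported from Keras into these arrays. One possible way, sketched below under stated assumptions: in Keras, `get_weights()` on a `Dense` layer returns `[kernel, bias]` with the kernel shaped (inputs, units). Each C# `layer` here merges three parallel `Dense` layers (relu, linear, sigmoid) in `Concatenate` order, so their biases are concatenated and their transposed kernels are stacked row-wise. The helpers `merge_dense` and `to_csharp` are hypothetical, and the demo uses made-up weights instead of a real model.

```python
import numpy as np


def merge_dense(parts):
    """Merge parallel Dense layers into one C#-style layer.

    `parts` is a list of (kernel, bias) pairs in Concatenate order, each kernel
    shaped (inputs, units). Returns (biases, weights) with weights shaped
    (total_units, inputs), i.e. one row per neuron.
    """
    biases = np.concatenate([b for _, b in parts])
    weights = np.concatenate([k.T for k, _ in parts], axis=0)
    return biases, weights


def to_csharp(name_b, name_w, biases, weights):
    """Format the arrays as C# initializers in the style of `double[] t = { ... }`."""
    b = ", ".join(repr(float(v)) for v in biases)
    rows = ", ".join("{ " + ", ".join(repr(float(v)) for v in row) + " }"
                     for row in weights)
    return f"double[] {name_b} = {{ {b} }};\ndouble[,] {name_w} = {{ {rows} }};"


# Made-up example: two parallel Dense layers, each with 2 inputs and 1 unit
parts = [(np.array([[1.0], [2.0]]), np.array([0.5])),
         (np.array([[3.0], [4.0]]), np.array([-0.5]))]
print(to_csharp("t", "t2", *merge_dense(parts)))
# double[] t = { 0.5, -0.5 };
# double[,] t2 = { { 1.0, 2.0 }, { 3.0, 4.0 } };
```

With a real model, `parts` would come from `layer.get_weights()` of the three `Dense` layers feeding each `Concatenate`; which Keras layer objects those are depends on how the model was built.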

Calling the neural network:

```csharp
public double StateRatingByNNet()
{
    double result = 0;
    List<double> xdata = new List<double>();
    // Fold along Z: if the ball flies towards "my" goal, look from the opponent's side
    double sign = 1;
    if (ball.velocity.Z < 0)
        sign = -1;
    // Mirror along X so that the folded X velocity is non-negative
    double signX = 1;
    if (ball.velocity.X * sign < 0)
        signX = -1;
    xdata.Add(ball.velocity.X * sign * signX);
    xdata.Add(ball.velocity.Y);
    xdata.Add(ball.velocity.Z * sign);
    xdata.Add(ball.position.X * sign * signX);
    xdata.Add(ball.position.Y);
    xdata.Add(ball.position.Z * sign);
    List<double> o = nnet.predict(xdata);
    // Unfold: the Z flip switched viewpoints, so flip the score back
    return result + o[0] * sign;
}
```

Thank you for your interest!

Source: https://habr.com/ru/post/id439886/

