Utilisation de données chiffrées pour l'apprentissage automatique sans les déchiffrer


Utilisation de données chiffrées pour l'apprentissage automatique sans les déchiffrer
Cet article décrit les techniques cryptographiques avancées. Ceci est juste un aperçu des recherches menées par Julia Computing. N'utilisez pas les exemples donnés ici dans les applications commerciales. Consultez toujours les cryptographes avant d'appliquer la cryptographie.

Ici, vous pouvez télécharger le package qui implémente toute la magie, et voici le code qui est discuté dans l'article.

Présentation


Supposons que vous venez de développer un nouveau modÚle d'apprentissage automatique sympa (bien sûr, en utilisant Flux.jl ). Et maintenant, vous voulez commencer à le déployer pour vos utilisateurs. Comment allez-vous faire cela? Le moyen le plus simple est probablement de donner le modÚle aux utilisateurs et de le laisser s'exécuter localement sur leurs données. Mais cette approche présente des inconvénients:

  1. Les modĂšles d'apprentissage automatique sont volumineux et les ordinateurs des utilisateurs peuvent ne pas disposer de suffisamment de ressources informatiques ou disque.
  2. Les modĂšles d'apprentissage automatique sont souvent mis Ă  jour et il peut ne pas ĂȘtre pratique pour vous d'envoyer rĂ©guliĂšrement de grandes quantitĂ©s de donnĂ©es sur le rĂ©seau.
  3. Le dĂ©veloppement de modĂšles prend du temps et nĂ©cessite une grande quantitĂ© de ressources informatiques. Et vous voudrez peut-ĂȘtre une compensation sous forme de frais pour l'utilisation de votre modĂšle.

Ensuite, ils se souviennent gĂ©nĂ©ralement que le modĂšle peut ĂȘtre fourni dans le cloud via l'API. Au cours des derniĂšres annĂ©es, de nombreux services de ce type sont apparus; chaque grande plateforme cloud offre des services similaires aux dĂ©veloppeurs d'entreprise. Mais les utilisateurs potentiels sont confrontĂ©s Ă  un dilemme Ă©vident: dĂ©sormais, leurs donnĂ©es sont traitĂ©es sur un serveur distant, qui peut ne pas ĂȘtre fiable. Cela a des implications Ă©thiques et juridiques claires qui limitent l'utilisation de ces services. Dans les secteurs rĂ©glementĂ©s, notamment les soins de santĂ© et les services financiers, il est souvent impossible d'envoyer des donnĂ©es sur les patients et les clients Ă  des tiers pour traitement.

D'autres options?

Il s'avĂšre que c'est le cas! Les dĂ©couvertes rĂ©centes en cryptographie permettent de calculer avec des donnĂ©es sans les dĂ©coder . Par exemple, un utilisateur envoie des donnĂ©es chiffrĂ©es (par exemple, des images) Ă  l'API cloud, qui lance un modĂšle d'apprentissage automatique, puis envoie une rĂ©ponse chiffrĂ©e. À aucun moment, les donnĂ©es ne sont dĂ©chiffrĂ©es, le fournisseur de cloud n'a pas accĂšs aux images sources et ne peut pas dĂ©chiffrer les prĂ©visions calculĂ©es. Comment est-ce possible? Voyons l'exemple de la crĂ©ation d'un service de reconnaissance de l'Ă©criture manuscrite sur des images chiffrĂ©es Ă  partir de l'ensemble de donnĂ©es MNIST.

À propos du chiffrement homomorphe


La capacité d'effectuer des calculs avec des données chiffrées est communément appelée «informatique sécurisée». Il s'agit d'un vaste domaine de recherche, avec de nombreuses approches de la cryptographie en fonction de toutes sortes de scénarios d'application. Nous nous concentrerons sur une technique appelée «cryptage homomorphique». Dans un tel systÚme, les opérations suivantes sont généralement disponibles pour nous:

  • pub_key, eval_key, priv_key = keygen()
  • encrypted = encrypt(pub_key, plaintext)
  • decrypted = decrypt(priv_key, encrypted)
  • encryptedâ€Č = eval(eval_key, f, encrypted)

Les trois premiĂšres opĂ©rations sont simples et familiĂšres Ă  tous ceux qui ont dĂ©jĂ  utilisĂ© des algorithmes de chiffrement asymĂ©triques (par exemple, si vous vous ĂȘtes connectĂ© via TLS). Toute magie opĂšre lors de la derniĂšre opĂ©ration. Lors du chiffrement, il Ă©value la fonction f et renvoie une autre valeur chiffrĂ©e calculĂ©e en fonction du rĂ©sultat de l'Ă©valuation f sur la valeur chiffrĂ©e. Cette caractĂ©ristique a donnĂ© son nom Ă  son approche. L'Ă©valuation est liĂ©e Ă  l'opĂ©ration de chiffrement:

 f(decrypt(priv_key, encrypted)) == decrypt(priv_key, eval(eval_key, f, encrypted)) 

De mĂȘme, en utilisant une valeur chiffrĂ©e, nous pouvons Ă©valuer des homomorphismes arbitraires f .

Les fonctions f prises en charge dĂ©pendent des schĂ©mas cryptographiques et des opĂ©rations prises en charge. Si un seul f pris en charge (par exemple, f = + ), alors le circuit est appelĂ© «partiellement homomorphe». Si f peut ĂȘtre un ensemble complet de passerelles, sur la base desquelles des schĂ©mas arbitraires peuvent ĂȘtre crĂ©Ă©s, alors pour une taille limitĂ©e du schĂ©ma, cela s'appelle un autre type de calcul partiellement homomorphe - "quelque peu homomorphique", et pour une taille illimitĂ©e - calcul "complĂštement homomorphique". Vous pouvez transformer "en quelque sorte" un cryptage complĂštement homomorphe en utilisant la technique d'amorçage, mais cela dĂ©passe le cadre de notre article. Le cryptage entiĂšrement homomorphe est une dĂ©couverte relativement rĂ©cente, le premier schĂ©ma de travail (bien que peu pratique) a Ă©tĂ© publiĂ© par Craig Gentry en 2009 . Il existe un certain nombre de schĂ©mas ultĂ©rieurs (et pratiques) complĂštement homomorphes. Il existe Ă©galement des progiciels qui mettent en Ɠuvre qualitativement ces schĂ©mas. Le plus souvent, ils utilisent Microsoft SEAL et PALISADE . De plus, j'ai rĂ©cemment ouvert le code d'implĂ©mentation de ces algorithmes Pure Julia . Pour cet article, nous utiliserons le cryptage CKKS implĂ©mentĂ©.

Présentation de CKS


CKKS (par les noms des auteurs de l' ouvrage scientifique Cheon-Kim-Kim-Song, qui a proposé l'algorithme en 2016) est un schéma de cryptage homomorphique qui permet une évaluation homomorphique des opérations primitives suivantes:

  • L'addition Ă©lĂ©ment par Ă©lĂ©ment des longueurs de n vecteurs de nombres complexes.
  • Multiplication par Ă©lĂ©ment des longueurs de n vecteurs complexes.
  • Faire pivoter (dans le contexte du circshift ) les Ă©lĂ©ments d'un vecteur.
  • Appariement intĂ©grĂ© des Ă©lĂ©ments vectoriels.

Le paramÚtre n dépend du niveau de sécurité et de précision souhaité et est généralement assez élevé. Dans notre exemple, il sera égal à 4096 (une valeur plus élevée augmente la sécurité, mais est également plus difficile dans les calculs, elle évolue approximativement comme n log n ).

De plus, les calculs utilisant CKKS sont bruyants . Par conséquent, les résultats sont approximatifs et il faut veiller à ce que les résultats soient évalués avec une précision suffisante pour ne pas affecter l'exactitude du résultat.

D'un autre cĂŽtĂ©, de telles restrictions ne sont pas inhabituelles pour les dĂ©veloppeurs de packages d'apprentissage automatique. Des accĂ©lĂ©rateurs spĂ©ciaux comme le GPU fonctionnent Ă©galement gĂ©nĂ©ralement avec des vecteurs numĂ©riques. De plus, pour de nombreux dĂ©veloppeurs, les nombres Ă  virgule flottante semblent parfois bruyants en raison de l'influence des algorithmes de sĂ©lection, du multithreading, etc. Je tiens Ă  souligner que la principale diffĂ©rence ici est que les calculs arithmĂ©tiques avec des nombres Ă  virgule flottante sont initialement dĂ©terministes, mĂȘme si cela n'est pas Ă©vident en raison de la complexitĂ© de la mise en Ɠuvre, bien que les primitives CKKS soient vraiment bruyantes. Mais cela permet peut-ĂȘtre aux utilisateurs de comprendre que le bruit n'est pas aussi effrayant qu'il n'y paraĂźt.

Voyons maintenant comment vous pouvez effectuer ces opérations dans Julia (remarque: des paramÚtres trÚs dangereux sont sélectionnés, avec ces opérations, nous n'illustrons que l'utilisation de la bibliothÚque dans REPL).

 julia> using ToyFHE # Let's play with 8 element vectors julia> N = 8; # Choose some parameters - we'll talk about it later julia> ℛ = NegacyclicRing(2N, (40, 40, 40)) ℀₁₃₂₉₂₂₇₉₉₇₅₆₈₀₈₁₄₅₇₄₀₂₇₀₁₂₀₇₁₀₄₂₄₈₂₅₇/(xÂč⁶ + 1) # We'll use CKKS julia> params = CKKSParams(ℛ) CKKS parameters # We need to pick a scaling factor for a numbers - again we'll talk about that later julia> Tscale = FixedRational{2^40} FixedRational{1099511627776,T} where T # Let's start with a plain Vector of zeros julia> plain = CKKSEncoding{Tscale}(zero(ℛ)) 8-element CKKSEncoding{FixedRational{1099511627776,T} where T} with indices 0:7: 0.0 + 0.0im 0.0 + 0.0im 0.0 + 0.0im 0.0 + 0.0im 0.0 + 0.0im 0.0 + 0.0im 0.0 + 0.0im 0.0 + 0.0im # Ok, we're ready to get started, but first we'll need some keys julia> kp = keygen(params) CKKS key pair julia> kp.priv CKKS private key julia> kp.pub CKKS public key # Alright, let's encrypt some things: julia> foreach(i->plain[i] = i+1, 0:7); plain 8-element CKKSEncoding{FixedRational{1099511627776,T} where T} with indices 0:7: 1.0 + 0.0im 2.0 + 0.0im 3.0 + 0.0im 4.0 + 0.0im 5.0 + 0.0im 6.0 + 0.0im 7.0 + 0.0im 8.0 + 0.0im julia> c = encrypt(kp.pub, plain) CKKS ciphertext (length 2, encoding CKKSEncoding{FixedRational{1099511627776,T} where T}) # And decrypt it again julia> decrypt(kp.priv, c) 8-element CKKSEncoding{FixedRational{1099511627776,T} where T} with indices 0:7: 0.9999999999995506 - 2.7335193113350057e-16im 1.9999999999989408 - 3.885780586188048e-16im 3.000000000000205 + 1.6772825551165524e-16im 4.000000000000538 - 3.885780586188048e-16im 4.999999999998865 + 8.382500573679615e-17im 6.000000000000185 + 4.996003610813204e-16im 7.000000000001043 - 2.0024593503998215e-16im 8.000000000000673 + 4.996003610813204e-16im # Note that we had some noise. Let's go through all the primitive operations we'll need: julia> decrypt(kp.priv, c+c) 8-element CKKSEncoding{FixedRational{1099511627776,T} where T} with indices 0:7: 1.9999999999991012 - 5.467038622670011e-16im 3.9999999999978817 - 7.771561172376096e-16im 6.00000000000041 + 3.354565110233105e-16im 8.000000000001076 - 7.771561172376096e-16im 9.99999999999773 + 1.676500114735923e-16im 12.00000000000037 + 9.992007221626409e-16im 14.000000000002085 - 4.004918700799643e-16im 16.000000000001346 + 9.992007221626409e-16im julia> csq = c*c CKKS ciphertext (length 3, encoding CKKSEncoding{FixedRational{1208925819614629174706176,T} where T}) julia> decrypt(kp.priv, csq) 8-element CKKSEncoding{FixedRational{1208925819614629174706176,T} where T} with indices 0:7: 0.9999999999991012 - 2.350516767363621e-15im 3.9999999999957616 - 5.773159728050814e-15im 9.000000000001226 - 2.534464540987068e-15im 16.000000000004306 - 2.220446049250313e-15im 24.99999999998865 + 2.0903753311370056e-15im 36.00000000000222 + 4.884981308350689e-15im 49.000000000014595 + 1.0182491378134327e-15im 64.00000000001077 + 4.884981308350689e-15im 

Si simple! Un lecteur attentif peut remarquer que CSQ est légÚrement différent du texte chiffré précédent. En particulier, le texte chiffré a une «longueur 3» et l'échelle est beaucoup plus grande. Une explication de ce que c'est et de ce qui est nécessaire dépasse le cadre de cet article. Il suffit de dire que nous devons baisser les valeurs avant de poursuivre les calculs, sinon le "lieu" se terminera dans le texte chiffré. Heureusement, nous pouvons réduire chacune des deux valeurs augmentées:

 # To get back down to length 2, we need to `keyswitch` (aka # relinerarize), which requires an evaluation key. Generating # this requires the private key. In a real application we would # have generated this up front and sent it along with the encrypted # data, but since we have the private key, we can just do it now. julia> ek = keygen(EvalMultKey, kp.priv) CKKS multiplication key julia> csq_length2 = keyswitch(ek, csq) CKKS ciphertext (length 2, encoding CKKSEncoding{FixedRational{1208925819614629174706176,T} where T}) # Getting the scale back down is done using modswitching. julia> csq_smaller = modswitch(csq_length2) CKKS ciphertext (length 2, encoding CKKSEncoding{FixedRational{1.099511626783e12,T} where T}) # And it still decrypts correctly (though note we've lost some precision) julia> decrypt(kp.priv, csq_smaller) 8-element CKKSEncoding{FixedRational{1.099511626783e12,T} where T} with indices 0:7: 0.9999999999802469 - 5.005163520332181e-11im 3.9999999999957723 - 1.0468514951188039e-11im 8.999999999998249 - 4.7588542623100616e-12im 16.000000000023014 - 1.0413447889166631e-11im 24.999999999955193 - 6.187833723406491e-12im 36.000000000002345 + 1.860733715346631e-13im 49.00000000001647 - 1.442396043149794e-12im 63.999999999988695 - 1.0722489563648028e-10im 

De plus, la commutation de modules (abréviation de commutation de module, commutation de module) réduit la taille du module de texte chiffré, nous ne pouvons donc pas continuer à le faire indéfiniment (nous utilisons un schéma de cryptage quelque peu homomorphique):

 julia> ℛ # Remember the ring we initially created ℀₁₃₂₉₂₂₇₉₉₇₅₆₈₀₈₁₄₅₇₄₀₂₇₀₁₂₀₇₁₀₄₂₄₈₂₅₇/(xÂč⁶ + 1) julia> ToyFHE.ring(csq_smaller) # It shrunk! ℀₁₂₀₈₉₂₅₈₂₀₁₄₄₅₉₃₇₇₉₃₃₁₅₅₃/(xÂč⁶ + 1)</code>     —  (rotations).      keyswitch,       (evaluation key,     ): <source lang="julia">julia> gk = keygen(GaloisKey, kp.priv; steps=2) CKKS galois key (element 25) julia> decrypt(circshift(c, gk)) decrypt(kp, circshift(c, gk)) 8-element CKKSEncoding{FixedRational{1099511627776,T} where T} with indices 0:7: 7.000000000001042 + 5.68459112632516e-16im 8.000000000000673 + 5.551115123125783e-17im 0.999999999999551 - 2.308655353580721e-16im 1.9999999999989408 + 2.7755575615628914e-16im 3.000000000000205 - 6.009767921608429e-16im 4.000000000000538 + 5.551115123125783e-17im 4.999999999998865 + 4.133860996136768e-17im 6.000000000000185 - 1.6653345369377348e-16im # And let's compare to doing the same on the plaintext julia> circshift(plain, 2) 8-element OffsetArray(::Array{Complex{Float64},1}, 0:7) with eltype Complex{Float64} with indices 0:7: 7.0 + 0.0im 8.0 + 0.0im 1.0 + 0.0im 2.0 + 0.0im 3.0 + 0.0im 4.0 + 0.0im 5.0 + 0.0im 6.0 + 0.0im 

Nous avons couvert les bases de l'utilisation de la bibliothÚque HE. Mais avant de passer à l'utilisation de ces primitives pour calculer les prévisions du réseau de neurones, regardons le processus d'apprentissage.

ModĂšle d'apprentissage automatique


Si vous n'ĂȘtes pas familier avec l'apprentissage automatique ou la bibliothĂšque Flux.jl, je vous recommande de parcourir rapidement la documentation de Flux.jl ou de voir une introduction gratuite Ă  l'apprentissage automatique , car nous ne discuterons que des modifications apportĂ©es Ă  l'application du modĂšle aux donnĂ©es chiffrĂ©es.

Commençons par utiliser le rĂ©seau de neurones convolutifs du zoo Flux . Nous allons effectuer le mĂȘme cycle de formation, avec la prĂ©paration des donnĂ©es et ainsi de suite, il suffit de configurer un peu le modĂšle. Le voici:

 function reshape_and_vcat(x) let y=reshape(x, 64, 4, size(x, 4)) vcat((y[:,i,:] for i=axes(y,2))...) end end model = Chain( # First convolution, operating upon a 28x28 image Conv((7, 7), 1=>4, stride=(3,3), x->x.^2), reshape_and_vcat, Dense(256, 64, x->x.^2), Dense(64, 10), ) 

Il s'agit du mĂȘme modĂšle que dans le travail «Secure Outsourced Matrix Computation and Application to Neural Networks» , qui utilise le mĂȘme schĂ©ma cryptographique avec deux diffĂ©rences: 1) pour des raisons de simplicitĂ©, nous n'avons pas chiffrĂ© le modĂšle lui-mĂȘme, et 2) aprĂšs chaque couche que nous avons Des vecteurs bayĂ©siens sont utilisĂ©s (dans Flux, cela se fait par dĂ©faut), je ne sais pas ce que c'Ă©tait dans le travail mentionnĂ©. Peut-ĂȘtre, en raison du deuxiĂšme point, la prĂ©cision sur l'ensemble de test de notre modĂšle s'est avĂ©rĂ©e ĂȘtre lĂ©gĂšrement plus Ă©levĂ©e (98,6% contre 98,1%), mais des diffĂ©rences hyperparamĂ©triques pourraient Ă©galement ĂȘtre la raison.

L'activation des fonctions x.^2 est inhabituelle (pour ceux qui ont de l'expérience en apprentissage automatique). Le plus souvent, dans de tels cas, ils utilisent le tanh , le relu ou quelque chose de plus fantaisiste. Mais bien que ces fonctions (en particulier relu ) soient facilement calculées pour les valeurs de texte ordinaires, cependant, elles peuvent nécessiter beaucoup de ressources informatiques pour les évaluer sous forme cryptée (nous estimons généralement l'approximation polynomiale). Heureusement, dans ce cas, x.^2 fonctionne trÚs bien.

Le reste du cycle d'apprentissage est restĂ© le mĂȘme. Nous avons supprimĂ© softmax du modĂšle pour la fonction de perte de logitcrossentropy (vous pouvez le laisser et Ă©valuer softmax aprĂšs dĂ©cryptage sur le client). Le code complet pour la formation du modĂšle se trouve sur GitHub , il s'exĂ©cute en quelques minutes sur n'importe quelle nouvelle carte vidĂ©o.

Opérations efficaces


Nous savons maintenant quelles opérations nous devons effectuer:

  • Coagulation.
  • ÉlĂ©ment au carrĂ©.
  • Multiplication matricielle.

Avec la quadrature, tout est simple, nous l'avons dĂ©jĂ  examinĂ© ci-dessus, nous allons donc considĂ©rer deux autres opĂ©rations. Nous supposons que la longueur du paquet de donnĂ©es est de 64 (vous remarquerez peut-ĂȘtre que les paramĂštres du modĂšle et la taille du paquet sont choisis de maniĂšre Ă  tirer parti du vecteur Ă  4096 Ă©lĂ©ments que nous avons obtenu Ă  la suite d'un choix rĂ©aliste de paramĂštres).

La coagulation


Rappelez-vous comment fonctionne la coagulation. Prenez une fenĂȘtre (dans notre cas 7x7) du tableau d'entrĂ©e d'origine, et chaque Ă©lĂ©ment de fenĂȘtre est multipliĂ© par un Ă©lĂ©ment de masque de convolution. Ensuite, nous dĂ©plaçons la fenĂȘtre Ă  une Ă©tape (dans notre cas, l'Ă©tape est 3, c'est-Ă -dire que nous dĂ©plaçons 3 Ă©lĂ©ments) et rĂ©pĂ©tons le processus (avec le mĂȘme masque de convolution). L'animation du processus ( source ) pour la convolution 3x3 avec l'Ă©tape (2, 2) montrĂ©e ci-dessous (tableau bleu - entrĂ©e, vert - sortie):


De plus, nous effectuons la convolution dans quatre «canaux» différents (c'est-à-dire que nous répétons la convolution 3 fois de plus avec des masques différents).

Maintenant que nous savons quoi faire, il reste à comprendre comment. Nous avons la chance que la convolution soit la premiÚre opération de notre modÚle. Par conséquent, afin d'économiser des ressources, nous pouvons prétraiter les données sur le client, puis les chiffrer (sans utiliser de poids). Faisons ça:

  • Tout d'abord, nous calculons chaque fenĂȘtre de convolution (c'est-Ă -dire un Ă©chantillon 7x7 Ă  partir des images source), ce qui nous donne 64 matrices 7x7 pour chaque image d'entrĂ©e. Notez que pour une fenĂȘtre 7x7 par incrĂ©ments de 2, il y aura des fenĂȘtres de convolution 8x8 pour Ă©valuer l'image d'entrĂ©e 28x28.
  • Collectons dans un vecteur les mĂȘmes positions dans chaque fenĂȘtre. Autrement dit, pour chaque image, nous aurons un vecteur Ă  64 Ă©lĂ©ments, ou un vecteur d'Ă©lĂ©ments 64x64 pour un paquet de taille 64 (un total de 49 matrices 64x64).
  • Nous crypterons.

La coagulation se transforme alors simplement en une multiplication scalaire de la matrice entiĂšre avec l'Ă©lĂ©ment de masque correspondant. Et en rĂ©sumant plus tard les 49 Ă©lĂ©ments, nous obtenons le rĂ©sultat du pliage. Voici Ă  quoi pourrait ressembler la mise en Ɠuvre de cette stratĂ©gie (en texte brut):

 function public_preprocess(batch) ka = OffsetArray(0:7, 0:7) # Create feature extracted matrix I = [[batch[iâ€Č*3 .+ (1:7), jâ€Č*3 .+ (1:7), 1, k] for iâ€Č=ka, jâ€Č=ka] for k = 1:64] # Reshape into the ciphertext Iá”ąâ±Œ = [[I[k][l...][i,j] for k=1:64, l=product(ka, ka)] for i=1:7, j=1:7] end Iá”ąâ±Œ = public_preprocess(batch) # Evaluate the convolution weights = model.layers[1].weight conv_weights = reverse(reverse(weights, dims=1), dims=2) conved = [sum(Iá”ąâ±Œ[i,j]*conv_weights[i,j,1,channel] for i=1:7, j=1:7) for channel = 1:4] conved = map(((x,b),)->x .+ b, zip(conved, model.layers[1].bias)) 

Ce (module pour changer la dimension) (modulo - changer l'ordre des tailles) donne la mĂȘme rĂ©ponse que l'opĂ©ration model.layers[1](batch) .

Ajoutez des opérations de chiffrement:

 Iá”ąâ±Œ = public_preprocess(batch) C_Iá”ąâ±Œ = map(Iá”ąâ±Œ) do Iij plain = CKKSEncoding{Tscale}(zero(plaintext_space(ckks_params))) plain .= OffsetArray(vec(Iij), 0:(NĂ·2-1)) encrypt(kp, plain) end weights = model.layers[1].weight conv_weights = reverse(reverse(weights, dims=1), dims=2) conved3 = [sum(C_Iá”ąâ±Œ[i,j]*conv_weights[i,j,1,channel] for i=1:7, j=1:7) for channel = 1:4] conved2 = map(((x,b),)->x .+ b, zip(conved3, model.layers[1].bias)) conved1 = map(ToyFHE.modswitch, conved2) 

Veuillez noter que l'interrupteur à clé n'est pas requis ici car les poids sont publics. Nous n'augmentons donc pas la longueur du texte chiffré.

Multiplication matricielle


Passant à la multiplication matricielle, nous pouvons utiliser la rotation des éléments dans le vecteur pour changer l'ordre des indices de multiplication. Envisagez le placement en ligne des éléments de matrice dans un vecteur. Si nous décalons le vecteur d'un multiple de la taille de la ligne, nous obtenons l'effet de la rotation des colonnes, qui est une primitive suffisante pour implémenter la multiplication matricielle (au moins les matrices carrées). Essayons:

 function matmul_square_reordered(weights, x) sum(1:size(weights, 1)) do k # We rotate the columns of the LHS and take the diagonal weight_diag = diag(circshift(weights, (0,(k-1)))) # We rotate the rows of the RHS x_rotated = circshift(x, (k-1,0)) # We do an elementwise, broadcast multiply weight_diag .* x_rotated end end function matmul_reorderd(weights, x) sum(partition(1:256, 64)) do range matmul_square_reordered(weights[:, range], x[range, :]) end end fc1_weights = model.layers[3].W x = rand(Float64, 256, 64) @assert (fc1_weights*x) ≈ matmul_reorderd(fc1_weights, x) 

Bien sûr, pour la multiplication matricielle générale, quelque chose de plus compliqué est nécessaire, mais pour l'instant cela suffit.

Améliorer la technique


Maintenant, tous les composants de notre technique fonctionnent. Voici le code entier (sauf pour définir les options de sélection et des choses similaires):

 ek = keygen(EvalMultKey, kp.priv) gk = keygen(GaloisKey, kp.priv; steps=64) Iá”ąâ±Œ = public_preprocess(batch) C_Iá”ąâ±Œ = map(Iá”ąâ±Œ) do Iij plain = CKKSEncoding{Tscale}(zero(plaintext_space(ckks_params))) plain .= OffsetArray(vec(Iij), 0:(NĂ·2-1)) encrypt(kp, plain) end weights = model.layers[1].weight conv_weights = reverse(reverse(weights, dims=1), dims=2) conved3 = [sum(C_Iá”ąâ±Œ[i,j]*conv_weights[i,j,1,channel] for i=1:7, j=1:7) for channel = 1:4] conved2 = map(((x,b),)->x .+ b, zip(conved3, model.layers[1].bias)) conved1 = map(ToyFHE.modswitch, conved2) Csqed1 = map(x->x*x, conved1) Csqed1 = map(x->keyswitch(ek, x), Csqed1) Csqed1 = map(ToyFHE.modswitch, Csqed1) function encrypted_matmul(gk, weights, x::ToyFHE.CipherText) result = repeat(diag(weights), inner=64).*x rotated = x for k = 2:64 rotated = ToyFHE.rotate(gk, rotated) result += repeat(diag(circshift(weights, (0,(k-1)))), inner=64) .* rotated end result end fq1_weights = model.layers[3].W Cfq1 = sum(enumerate(partition(1:256, 64))) do (i,range) encrypted_matmul(gk, fq1_weights[:, range], Csqed1[i]) end Cfq1 = Cfq1 .+ OffsetArray(repeat(model.layers[3].b, inner=64), 0:4095) Cfq1 = modswitch(Cfq1) Csqed2 = Cfq1*Cfq1 Csqed2 = keyswitch(ek, Csqed2) Csqed2 = modswitch(Csqed2) function naive_rectangular_matmul(gk, weights, x) @assert size(weights, 1) < size(weights, 2) weights = vcat(weights, zeros(eltype(weights), size(weights, 2)-size(weights, 1), size(weights, 2))) encrypted_matmul(gk, weights, x) end fq2_weights = model.layers[4].W Cresult = naive_rectangular_matmul(gk, fq2_weights, Csqed2) Cresult = Cresult .+ OffsetArray(repeat(vcat(model.layers[4].b, zeros(54)), inner=64), 0:4095) 

Cela n'a pas l'air trop soigné, mais si vous avez fait tout cela, vous devez comprendre chaque étape.
Réfléchissons maintenant aux abstractions qui pourraient simplifier nos vies. Nous quittons le domaine de la cartographie et du machine learning et passons à l'architecture du langage de programmation, profitons donc du fait que Julia vous permet d'utiliser et de créer des abstractions puissantes. Par exemple, vous pouvez encapsuler l'ensemble du processus d'extraction des convolutions dans votre type de tableau:

 using BlockArrays """ ExplodedConvArray{T, Dims, Storage} <: AbstractArray{T, 4} Represents a an `nxmx1xb` array of images, but rearranged into a series of convolution windows. Evaluating a convolution compatible with `Dims` on this array is achievable through a sequence of scalar multiplications and sums on the underling storage. """ struct ExplodedConvArray{T, Dims, Storage} <: AbstractArray{T, 4} # sx*sy matrix of b*(dx*dy) matrices of extracted elements # where (sx, sy) = kernel_size(Dims) # (dx, dy) = output_size(DenseConvDims(...)) cdims::Dims x::Matrix{Storage} function ExplodedConvArray{T, Dims, Storage}(cdims::Dims, storage::Matrix{Storage}) where {T, Dims, Storage} @assert all(==(size(storage[1])), size.(storage)) new{T, Dims, Storage}(cdims, storage) end end Base.size(ex::ExplodedConvArray) = (NNlib.input_size(ex.cdims)..., 1, size(ex.x[1], 1)) function ExplodedConvArray{T}(cdims, batch::AbstractArray{T, 4}) where {T} x, y = NNlib.output_size(cdims) kx, ky = NNlib.kernel_size(cdims) stridex, stridey = NNlib.stride(cdims) kax = OffsetArray(0:x-1, 0:x-1) kay = OffsetArray(0:x-1, 0:x-1) I = [[batch[iâ€Č*stridex .+ (1:kx), jâ€Č*stridey .+ (1:ky), 1, k] for iâ€Č=kax, jâ€Č=kay] for k = 1:size(batch, 4)] Iá”ąâ±Œ = [[I[k][l...][i,j] for k=1:size(batch, 4), l=product(kax, kay)] for (i,j) in product(1:kx, 1:ky)] ExplodedConvArray{T, typeof(cdims), eltype(Iá”ąâ±Œ)}(cdims, Iá”ąâ±Œ) end function NNlib.conv(x::ExplodedConvArray{<:Any, Dims}, weights::AbstractArray{<:Any, 4}, cdims::Dims) where {Dims<:ConvDims} blocks = reshape([ Base.ReshapedArray(sum(xx[i,j]*weights[i,j,1,channel] for i=1:7, j=1:7), (NNlib.output_size(cdims)...,1,size(x, 4)), ()) for channel = 1:4 ],(1,1,4,1)) BlockArrays._BlockArray(blocks, BlockArrays.BlockSizes([8], [8], [1,1,1,1], [64])) end 

Ici, nous avons de nouveau utilisé BlockArrays pour représenter un tableau 8x8x4x64 comme quatre tableaux 8x8x1x64 comme dans le code source. Maintenant, la présentation de la premiÚre étape est devenue beaucoup plus belle, au moins avec des tableaux non chiffrés:

 julia> cdims = DenseConvDims(batch, model.layers[1].weight; stride=(3,3), padding=(0,0,0,0), dilation=(1,1)) DenseConvDims: (28, 28, 1) * (7, 7) -> (8, 8, 4), stride: (3, 3) pad: (0, 0, 0, 0), dil: (1, 1), flip: false julia> a = ExplodedConvArray{eltype(batch)}(cdims, batch); julia> model(a) 10×64 Array{Float32,2}: [snip] 

Maintenant, comment pouvons-nous connecter cela avec le cryptage? Pour ce faire, vous avez besoin de:

  1. Chiffrez la structure ( ExplodedConvArray ) afin que nous obtenions le texte chiffrĂ© pour chaque champ. Les opĂ©rations avec une telle structure cryptĂ©e vĂ©rifieront ce que la fonction ferait avec la structure d'origine et feront la mĂȘme chose de maniĂšre homomorphe.
  2. Intercepter certaines opérations afin de les effectuer différemment dans un contexte chiffré.

Heureusement, Julia nous fournit une abstraction pour cela: un plugin de compilation qui utilise le mécanisme Cassette.jl . Je ne vous dirai pas ce que c'est et comment cela fonctionne, je dirai briÚvement qu'il peut déterminer le contexte, par exemple, Encrypted , puis il définit les rÚgles de fonctionnement des opérations dans ce contexte. Par exemple, vous pouvez écrire ceci pour la deuxiÚme exigence:

 # Define Matrix multiplication between an array and an encrypted block array function (*::Encrypted{typeof(*)})(a::Array{T, 2}, b::Encrypted{<:BlockArray{T, 2}}) where {T} sum(a*b for (i,range) in enumerate(partition(1:size(a, 2), size(b.blocks[1], 1)))) end # Define Matrix multiplication between an array and an encrypted array function (*::Encrypted{typeof(*)})(a::Array{T, 2}, b::Encrypted{Array{T, 2}}) where {T} result = repeat(diag(a), inner=size(a, 1)).*x rotated = b for k = 2:size(a, 2) rotated = ToyFHE.rotate(GaloisKey(*), rotated) result += repeat(diag(circshift(a, (0,(k-1)))), inner=size(a, 1)) .* rotated end result end 

En conséquence, l'utilisateur pourra écrire tout ce qui précÚde avec un minimum de travail manuel:

 kp = keygen(ckks_params) ek = keygen(EvalMultKey, kp.priv) gk = keygen(GaloisKey, kp.priv; steps=64) # Create evaluation context ctx = Encrypted(ek, gk) # Do public preprocessing batch = ExplodedConvArray{eltype(batch)}(cdims, batch); # Run on encrypted data under the encryption context Cresult = ctx(model)(encrypt(kp.pub, batch)) # Decrypt the answer decrypt(kp, Cresult) 

, . ( ℛ, modswitch, keyswitch ..) , . , , , , .

Conclusion


— . Julia . RAMPARTS ( paper , JuliaCon talk ) : Julia- - PALISADE. Julia Computing RAMPARTS Verona, . , . . , , .

, ToyFHE . , , , .

Source: https://habr.com/ru/post/fr478514/


All Articles