Usar datos cifrados para el aprendizaje automático sin descifrarlos


Usar datos cifrados para el aprendizaje automático sin descifrarlos
Este artículo trata sobre técnicas criptográficas avanzadas. Esta es solo una descripción general de la investigación realizada por Julia Computing. No use los ejemplos dados aquí en aplicaciones comerciales. Siempre consulte con los criptógrafos antes de aplicar la criptografía.

Aquí puede descargar el paquete que implementa toda la magia, y aquí está el código que se discute en el artículo.

Introduccion


Digamos que acaba de desarrollar un nuevo modelo de aprendizaje automático (por supuesto, usando Flux.jl ). Y ahora desea comenzar a implementarlo para sus usuarios. ¿Cómo vas a hacer esto? Probablemente la forma más fácil es dar el modelo a los usuarios y dejar que se ejecute localmente en sus datos. Pero este enfoque tiene desventajas:

  1. Los modelos de aprendizaje automático son grandes, y las computadoras de los usuarios pueden no tener suficientes recursos informáticos o de disco.
  2. Los modelos de aprendizaje automático a menudo se actualizan, y puede que no sea conveniente para usted enviar regularmente grandes cantidades de datos a través de la red.
  3. El desarrollo del modelo lleva mucho tiempo y requiere una gran cantidad de recursos informáticos. Y es posible que desee una compensación por esto en forma de una tarifa por usar su modelo.

Luego, por lo general, recuerdan que el modelo se puede proporcionar en la nube a través de la API. En los últimos años, han aparecido muchos de estos servicios; cada gran plataforma en la nube ofrece servicios similares a los desarrolladores corporativos. Pero los usuarios potenciales se enfrentan a un dilema obvio: ahora sus datos se procesan en un servidor remoto, que puede no ser confiable. Esto tiene claras implicaciones éticas y legales que limitan el uso de dichos servicios. En las industrias reguladas, especialmente los servicios de salud y financieros, a menudo no es posible enviar datos de pacientes y clientes a terceros para su procesamiento.

¿Alguna otra opción?

Resulta que hay! Los descubrimientos recientes en criptografía permiten computar con datos sin decodificarlos . Por ejemplo, un usuario envía datos encriptados (por ejemplo, imágenes) a la API en la nube, que lanza un modelo de aprendizaje automático, y luego envía una respuesta encriptada. En ningún momento se descifran los datos, el proveedor de la nube no obtiene acceso a las imágenes de origen y no puede descifrar el pronóstico calculado. ¿Cómo es esto posible? Veamos el ejemplo de crear un servicio para el reconocimiento de escritura a mano en imágenes encriptadas del conjunto de datos MNIST.

Acerca del cifrado homomórfico


La capacidad de realizar cálculos con datos cifrados se conoce comúnmente como "computación segura". Esta es un área grande para la investigación, con numerosos enfoques de criptografía que dependen de todo tipo de escenarios de aplicación. Nos centraremos en una técnica llamada "encriptación homomórfica". En dicho sistema, las siguientes operaciones generalmente están disponibles para nosotros:

  • pub_key, eval_key, priv_key = keygen()
  • encrypted = encrypt(pub_key, plaintext)
  • decrypted = decrypt(priv_key, encrypted)
  • encrypted′ = eval(eval_key, f, encrypted)

Las primeras tres operaciones son simples y familiares para todos los que ya han utilizado algoritmos de cifrado asimétricos (por ejemplo, si se conectó a través de TLS). Toda la magia sucede en la última operación. Durante el cifrado, evalúa la función f y devuelve otro valor cifrado calculado de acuerdo con el resultado de evaluar f en el valor cifrado. Esta característica le dio a su enfoque su nombre. La evaluación está relacionada con la operación de cifrado:

 f(decrypt(priv_key, encrypted)) == decrypt(priv_key, eval(eval_key, f, encrypted)) 

Del mismo modo, utilizando un valor cifrado, podemos evaluar homomorfismos arbitrarios f .

Las funciones f admite f dependen de los esquemas criptográficos y las operaciones admitidas. Si solo f admite una f (por ejemplo, f = + ), el circuito se llama "parcialmente homomórfico". Si f puede ser un conjunto completo de puertas de enlace, sobre la base de las cuales se pueden crear esquemas arbitrarios, entonces para un tamaño limitado de un esquema esto se llama otro tipo de cálculo parcialmente homomórfico - "algo homomórfico", y para un tamaño ilimitado - cálculo "completamente homomórfico". Puede convertir "de alguna manera" en un cifrado completamente homomórfico utilizando la técnica de arranque, pero esto está más allá del alcance de nuestro artículo. El cifrado totalmente homomórfico es un descubrimiento relativamente reciente, el primer esquema de trabajo (aunque poco práctico) fue publicado por Craig Gentry en 2009 . Hay una serie de esquemas completamente homomórficos posteriores (y prácticos). También hay paquetes de software que implementan cualitativamente estos esquemas. La mayoría de las veces usan Microsoft SEAL y PALISADE . Además, recientemente abrí el código de implementación para estos algoritmos de Pure Julia . Para este artículo, utilizaremos el cifrado CKKS implementado en él.

Descripción general de CKS


CKKS (por los nombres de los autores del trabajo científico Cheon-Kim-Kim-Song, que propuso el algoritmo en 2016) es un esquema de cifrado homomórfico que permite la evaluación homomórfica de las siguientes operaciones primitivas:

  • La adición por elementos de las longitudes de n vectores de números complejos.
  • Multiplicación por elementos de las longitudes de n vectores complejos.
  • Rotar elementos (en el contexto de un circshift ) en un vector.
  • Emparejamiento integrado de elementos vectoriales.

El parámetro n depende del nivel deseado de seguridad y precisión, y generalmente es bastante alto. En nuestro ejemplo, será igual a 4096 (un valor más alto aumenta la seguridad, pero también es más difícil en los cálculos, se escala aproximadamente como n log n ).

Además, los cálculos con CKKS son ruidosos . Por lo tanto, los resultados son aproximados, y se debe tener cuidado de que los resultados se evalúen con suficiente precisión para no afectar la exactitud del resultado.

Por otro lado, tales restricciones no son inusuales para los desarrolladores de paquetes de aprendizaje automático. Los aceleradores especiales como la GPU también suelen operar con vectores de números. Además, para muchos desarrolladores, los números de coma flotante a veces parecen ruidosos debido a la influencia de algoritmos de selección, subprocesos múltiples, etc. Quiero enfatizar que la diferencia clave aquí es que los cálculos aritméticos con números de coma flotante son inicialmente deterministas, incluso si esto no es obvio debido a la complejidad de la implementación, aunque las primitivas CKKS son realmente ruidosas. Pero tal vez esto les permita a los usuarios comprender que el ruido no es tan aterrador como podría parecer.

Ahora veamos cómo puede realizar estas operaciones en Julia (nota: se seleccionan parámetros muy inseguros, con estas operaciones solo ilustramos el uso de la biblioteca en REPL).

 julia> using ToyFHE # Let's play with 8 element vectors julia> N = 8; # Choose some parameters - we'll talk about it later julia> ℛ = NegacyclicRing(2N, (40, 40, 40)) ℤ₁₃₂₉₂₂₇₉₉₇₅₆₈₀₈₁₄₅₇₄₀₂₇₀₁₂₀₇₁₀₄₂₄₈₂₅₇/(x¹⁶ + 1) # We'll use CKKS julia> params = CKKSParams(ℛ) CKKS parameters # We need to pick a scaling factor for a numbers - again we'll talk about that later julia> Tscale = FixedRational{2^40} FixedRational{1099511627776,T} where T # Let's start with a plain Vector of zeros julia> plain = CKKSEncoding{Tscale}(zero(ℛ)) 8-element CKKSEncoding{FixedRational{1099511627776,T} where T} with indices 0:7: 0.0 + 0.0im 0.0 + 0.0im 0.0 + 0.0im 0.0 + 0.0im 0.0 + 0.0im 0.0 + 0.0im 0.0 + 0.0im 0.0 + 0.0im # Ok, we're ready to get started, but first we'll need some keys julia> kp = keygen(params) CKKS key pair julia> kp.priv CKKS private key julia> kp.pub CKKS public key # Alright, let's encrypt some things: julia> foreach(i->plain[i] = i+1, 0:7); plain 8-element CKKSEncoding{FixedRational{1099511627776,T} where T} with indices 0:7: 1.0 + 0.0im 2.0 + 0.0im 3.0 + 0.0im 4.0 + 0.0im 5.0 + 0.0im 6.0 + 0.0im 7.0 + 0.0im 8.0 + 0.0im julia> c = encrypt(kp.pub, plain) CKKS ciphertext (length 2, encoding CKKSEncoding{FixedRational{1099511627776,T} where T}) # And decrypt it again julia> decrypt(kp.priv, c) 8-element CKKSEncoding{FixedRational{1099511627776,T} where T} with indices 0:7: 0.9999999999995506 - 2.7335193113350057e-16im 1.9999999999989408 - 3.885780586188048e-16im 3.000000000000205 + 1.6772825551165524e-16im 4.000000000000538 - 3.885780586188048e-16im 4.999999999998865 + 8.382500573679615e-17im 6.000000000000185 + 4.996003610813204e-16im 7.000000000001043 - 2.0024593503998215e-16im 8.000000000000673 + 4.996003610813204e-16im # Note that we had some noise. Let's go through all the primitive operations we'll need: julia> decrypt(kp.priv, c+c) 8-element CKKSEncoding{FixedRational{1099511627776,T} where T} with indices 0:7: 1.9999999999991012 - 5.467038622670011e-16im 3.9999999999978817 - 7.771561172376096e-16im 6.00000000000041 + 3.354565110233105e-16im 8.000000000001076 - 7.771561172376096e-16im 9.99999999999773 + 1.676500114735923e-16im 12.00000000000037 + 9.992007221626409e-16im 14.000000000002085 - 4.004918700799643e-16im 16.000000000001346 + 9.992007221626409e-16im julia> csq = c*c CKKS ciphertext (length 3, encoding CKKSEncoding{FixedRational{1208925819614629174706176,T} where T}) julia> decrypt(kp.priv, csq) 8-element CKKSEncoding{FixedRational{1208925819614629174706176,T} where T} with indices 0:7: 0.9999999999991012 - 2.350516767363621e-15im 3.9999999999957616 - 5.773159728050814e-15im 9.000000000001226 - 2.534464540987068e-15im 16.000000000004306 - 2.220446049250313e-15im 24.99999999998865 + 2.0903753311370056e-15im 36.00000000000222 + 4.884981308350689e-15im 49.000000000014595 + 1.0182491378134327e-15im 64.00000000001077 + 4.884981308350689e-15im 

Tan simple! Un lector atento podría notar que CSQ es ligeramente diferente del texto cifrado anterior. En particular, el texto cifrado tiene "longitud 3" y la escala es mucho mayor. Una explicación de qué es esto y qué se necesita está más allá del alcance de este artículo. Baste decir que necesitamos bajar los valores antes de continuar con los cálculos, de lo contrario el "lugar" terminará en el texto cifrado. Afortunadamente, podemos reducir cada uno de los dos valores aumentados:

 # To get back down to length 2, we need to `keyswitch` (aka # relinerarize), which requires an evaluation key. Generating # this requires the private key. In a real application we would # have generated this up front and sent it along with the encrypted # data, but since we have the private key, we can just do it now. julia> ek = keygen(EvalMultKey, kp.priv) CKKS multiplication key julia> csq_length2 = keyswitch(ek, csq) CKKS ciphertext (length 2, encoding CKKSEncoding{FixedRational{1208925819614629174706176,T} where T}) # Getting the scale back down is done using modswitching. julia> csq_smaller = modswitch(csq_length2) CKKS ciphertext (length 2, encoding CKKSEncoding{FixedRational{1.099511626783e12,T} where T}) # And it still decrypts correctly (though note we've lost some precision) julia> decrypt(kp.priv, csq_smaller) 8-element CKKSEncoding{FixedRational{1.099511626783e12,T} where T} with indices 0:7: 0.9999999999802469 - 5.005163520332181e-11im 3.9999999999957723 - 1.0468514951188039e-11im 8.999999999998249 - 4.7588542623100616e-12im 16.000000000023014 - 1.0413447889166631e-11im 24.999999999955193 - 6.187833723406491e-12im 36.000000000002345 + 1.860733715346631e-13im 49.00000000001647 - 1.442396043149794e-12im 63.999999999988695 - 1.0722489563648028e-10im 

Además, la conmutación de modulación (abreviatura de conmutación de módulo, conmutación de módulo) reduce el tamaño del módulo de texto cifrado, por lo que no podemos continuar haciendo esto indefinidamente (utilizamos un esquema de cifrado algo homomórfico):

 julia> ℛ # Remember the ring we initially created ℤ₁₃₂₉₂₂₇₉₉₇₅₆₈₀₈₁₄₅₇₄₀₂₇₀₁₂₀₇₁₀₄₂₄₈₂₅₇/(x¹⁶ + 1) julia> ToyFHE.ring(csq_smaller) # It shrunk! ℤ₁₂₀₈₉₂₅₈₂₀₁₄₄₅₉₃₇₇₉₃₃₁₅₅₃/(x¹⁶ + 1)</code>     —  (rotations).      keyswitch,       (evaluation key,     ): <source lang="julia">julia> gk = keygen(GaloisKey, kp.priv; steps=2) CKKS galois key (element 25) julia> decrypt(circshift(c, gk)) decrypt(kp, circshift(c, gk)) 8-element CKKSEncoding{FixedRational{1099511627776,T} where T} with indices 0:7: 7.000000000001042 + 5.68459112632516e-16im 8.000000000000673 + 5.551115123125783e-17im 0.999999999999551 - 2.308655353580721e-16im 1.9999999999989408 + 2.7755575615628914e-16im 3.000000000000205 - 6.009767921608429e-16im 4.000000000000538 + 5.551115123125783e-17im 4.999999999998865 + 4.133860996136768e-17im 6.000000000000185 - 1.6653345369377348e-16im # And let's compare to doing the same on the plaintext julia> circshift(plain, 2) 8-element OffsetArray(::Array{Complex{Float64},1}, 0:7) with eltype Complex{Float64} with indices 0:7: 7.0 + 0.0im 8.0 + 0.0im 1.0 + 0.0im 2.0 + 0.0im 3.0 + 0.0im 4.0 + 0.0im 5.0 + 0.0im 6.0 + 0.0im 

Cubrimos los conceptos básicos del uso de la biblioteca HE. Pero antes de seguir usando estas primitivas para calcular los pronósticos de la red neuronal, veamos el proceso de aprendizaje.

Modelo de aprendizaje automático


Si no está familiarizado con el aprendizaje automático o la biblioteca Flux.jl, le recomiendo una revisión rápida de la documentación de Flux.jl o vea una introducción gratuita al aprendizaje automático , porque solo discutiremos los cambios en la aplicación del modelo a los datos cifrados.

Comencemos usando la red neuronal convolucional del zoológico Flux . Llevaremos a cabo el mismo ciclo de capacitación, con preparación de datos, etc., solo configuraremos un poco el modelo. Aquí esta:

 function reshape_and_vcat(x) let y=reshape(x, 64, 4, size(x, 4)) vcat((y[:,i,:] for i=axes(y,2))...) end end model = Chain( # First convolution, operating upon a 28x28 image Conv((7, 7), 1=>4, stride=(3,3), x->x.^2), reshape_and_vcat, Dense(256, 64, x->x.^2), Dense(64, 10), ) 

Este es el mismo modelo que en el trabajo "Computación y aplicación de matriz de outsourcing seguro a redes neuronales" , que utiliza el mismo esquema criptográfico con dos diferencias: 1) en aras de la simplicidad, no ciframos el modelo en sí, y 2) después de cada capa tenemos Se usan vectores bayesianos (en Flux, esto se hace de manera predeterminada), no estoy seguro de qué era en el trabajo mencionado. Quizás, debido al segundo punto, la precisión en el conjunto de pruebas de nuestro modelo resultó ser ligeramente mayor (98.6% versus 98.1%), pero las diferencias hiperparamétricas también podrían ser la razón.

Inusual (para aquellos que tienen experiencia en aprendizaje automático) es la activación de funciones x.^2 . Muy a menudo en tales casos usan tanh , relu o algo más imaginario. Pero aunque estas funciones (especialmente relu ) se calculan fácilmente para valores de texto ordinarios, sin embargo, pueden requerir muchos recursos informáticos para evaluarlas en forma cifrada (generalmente estimamos la aproximación polinómica). Afortunadamente, en este caso x.^2 funciona muy bien.

El resto del ciclo de aprendizaje permaneció igual. softmax del modelo para la logitcrossentropy función de logitcrossentropy (puede dejarlo y evaluar softmax después del descifrado en el cliente). El código completo para entrenar el modelo se encuentra en GitHub , se ejecuta en unos minutos en cualquier tarjeta de video nueva.

Operaciones efectivas


Ahora sabemos qué operaciones debemos realizar:

  • Coagulación
  • Elemento cuadrado.
  • Multiplicación matricial.

Con la cuadratura todo es simple, ya lo hemos examinado anteriormente, por lo que consideraremos otras dos operaciones. Suponemos que la longitud del paquete de datos es 64 (puede observar que los parámetros del modelo y el tamaño del paquete se eligen para aprovechar el vector de 4096 elementos que obtuvimos como resultado de una elección realista de parámetros).

Coagulación


Recordemos cómo funciona la coagulación. Tome una ventana (en nuestro caso 7x7) de la matriz de entrada original, y cada elemento de la ventana se multiplica por un elemento de máscara de convolución. Luego movemos la ventana a algún paso (en nuestro caso, el paso es 3, es decir, movemos 3 elementos) y repetimos el proceso (con la misma máscara de convolución). La animación del proceso ( fuente ) para la convolución 3x3 con el paso (2, 2) muestra a continuación (matriz azul - entrada, verde - salida):


Además, realizamos convolución en cuatro "canales" diferentes (es decir, repetimos la convolución 3 veces más con diferentes máscaras).

Ahora que sabemos qué hacer, queda por entender cómo. Somos afortunados de que la convolución sea la primera operación en nuestro modelo. Como resultado, para ahorrar recursos, podemos preprocesar los datos en el cliente y luego encriptarlos (sin usar pesos). Hagamos esto:

  • Primero, calculamos cada ventana de convolución (es decir, una muestra de 7x7 de las imágenes de origen), lo que nos da 64 matrices de 7x7 para cada imagen de entrada. Tenga en cuenta que para una ventana de 7x7 en incrementos de 2, habrá ventanas de convolución de 8x8 para evaluar la imagen de entrada de 28x28.
  • Recolectemos en un vector las mismas posiciones en cada ventana. Es decir, para cada imagen tendremos un vector de 64 elementos, o un vector de elementos de 64x64 para un paquete de tamaño 64 (un total de 49 matrices de 64x64).
  • Lo encriptaremos.

Luego, la coagulación simplemente se convierte en una multiplicación escalar de toda la matriz con el elemento de máscara correspondiente. Y resumiendo más tarde los 49 elementos, obtenemos el resultado del plegado. Así es como se vería la implementación de esta estrategia (en texto plano):

 function public_preprocess(batch) ka = OffsetArray(0:7, 0:7) # Create feature extracted matrix I = [[batch[i′*3 .+ (1:7), j′*3 .+ (1:7), 1, k] for i′=ka, j′=ka] for k = 1:64] # Reshape into the ciphertext Iᵢⱼ = [[I[k][l...][i,j] for k=1:64, l=product(ka, ka)] for i=1:7, j=1:7] end Iᵢⱼ = public_preprocess(batch) # Evaluate the convolution weights = model.layers[1].weight conv_weights = reverse(reverse(weights, dims=1), dims=2) conved = [sum(Iᵢⱼ[i,j]*conv_weights[i,j,1,channel] for i=1:7, j=1:7) for channel = 1:4] conved = map(((x,b),)->x .+ b, zip(conved, model.layers[1].bias)) 

Este (módulo para cambiar la dimensión) (módulo: cambiar el orden de los tamaños) da la misma respuesta que la operación model.layers[1](batch) .

Añadir operaciones de cifrado:

 Iᵢⱼ = public_preprocess(batch) C_Iᵢⱼ = map(Iᵢⱼ) do Iij plain = CKKSEncoding{Tscale}(zero(plaintext_space(ckks_params))) plain .= OffsetArray(vec(Iij), 0:(N÷2-1)) encrypt(kp, plain) end weights = model.layers[1].weight conv_weights = reverse(reverse(weights, dims=1), dims=2) conved3 = [sum(C_Iᵢⱼ[i,j]*conv_weights[i,j,1,channel] for i=1:7, j=1:7) for channel = 1:4] conved2 = map(((x,b),)->x .+ b, zip(conved3, model.layers[1].bias)) conved1 = map(ToyFHE.modswitch, conved2) 

Tenga en cuenta que no se requiere el interruptor de llave aquí porque los pesos son públicos. Por lo tanto, no aumentamos la longitud del texto cifrado.

Multiplicación de matrices


Pasando a la multiplicación matricial, podemos usar la rotación de elementos en el vector para cambiar el orden de los índices de multiplicación. Considere la colocación en fila de elementos de matriz en un vector. Si cambiamos el vector por un múltiplo del tamaño de la fila, obtenemos el efecto de la rotación de la columna, que es una primitiva suficiente para implementar la multiplicación de matrices (al menos matrices cuadradas). Probemos

 function matmul_square_reordered(weights, x) sum(1:size(weights, 1)) do k # We rotate the columns of the LHS and take the diagonal weight_diag = diag(circshift(weights, (0,(k-1)))) # We rotate the rows of the RHS x_rotated = circshift(x, (k-1,0)) # We do an elementwise, broadcast multiply weight_diag .* x_rotated end end function matmul_reorderd(weights, x) sum(partition(1:256, 64)) do range matmul_square_reordered(weights[:, range], x[range, :]) end end fc1_weights = model.layers[3].W x = rand(Float64, 256, 64) @assert (fc1_weights*x) ≈ matmul_reorderd(fc1_weights, x) 

Por supuesto, para la multiplicación matricial general, se requiere algo más complicado, pero por ahora esto es suficiente.

Mejorando la técnica


Ahora todos los componentes de nuestra técnica funcionan. Aquí está el código completo (excepto para configurar las opciones de selección y cosas similares):

 ek = keygen(EvalMultKey, kp.priv) gk = keygen(GaloisKey, kp.priv; steps=64) Iᵢⱼ = public_preprocess(batch) C_Iᵢⱼ = map(Iᵢⱼ) do Iij plain = CKKSEncoding{Tscale}(zero(plaintext_space(ckks_params))) plain .= OffsetArray(vec(Iij), 0:(N÷2-1)) encrypt(kp, plain) end weights = model.layers[1].weight conv_weights = reverse(reverse(weights, dims=1), dims=2) conved3 = [sum(C_Iᵢⱼ[i,j]*conv_weights[i,j,1,channel] for i=1:7, j=1:7) for channel = 1:4] conved2 = map(((x,b),)->x .+ b, zip(conved3, model.layers[1].bias)) conved1 = map(ToyFHE.modswitch, conved2) Csqed1 = map(x->x*x, conved1) Csqed1 = map(x->keyswitch(ek, x), Csqed1) Csqed1 = map(ToyFHE.modswitch, Csqed1) function encrypted_matmul(gk, weights, x::ToyFHE.CipherText) result = repeat(diag(weights), inner=64).*x rotated = x for k = 2:64 rotated = ToyFHE.rotate(gk, rotated) result += repeat(diag(circshift(weights, (0,(k-1)))), inner=64) .* rotated end result end fq1_weights = model.layers[3].W Cfq1 = sum(enumerate(partition(1:256, 64))) do (i,range) encrypted_matmul(gk, fq1_weights[:, range], Csqed1[i]) end Cfq1 = Cfq1 .+ OffsetArray(repeat(model.layers[3].b, inner=64), 0:4095) Cfq1 = modswitch(Cfq1) Csqed2 = Cfq1*Cfq1 Csqed2 = keyswitch(ek, Csqed2) Csqed2 = modswitch(Csqed2) function naive_rectangular_matmul(gk, weights, x) @assert size(weights, 1) < size(weights, 2) weights = vcat(weights, zeros(eltype(weights), size(weights, 2)-size(weights, 1), size(weights, 2))) encrypted_matmul(gk, weights, x) end fq2_weights = model.layers[4].W Cresult = naive_rectangular_matmul(gk, fq2_weights, Csqed2) Cresult = Cresult .+ OffsetArray(repeat(vcat(model.layers[4].b, zeros(54)), inner=64), 0:4095) 

No se ve muy bien, pero si hiciste todo esto, deberías entender cada paso.
Ahora pensemos en qué abstracciones podrían simplificar nuestras vidas. Estamos dejando el campo de la cartografía y el aprendizaje automático y avanzando hacia la arquitectura del lenguaje de programación, así que aprovechemos el hecho de que Julia le permite usar y crear poderosas abstracciones. Por ejemplo, puede encapsular todo el proceso de extracción de convoluciones en su tipo de matriz:

 using BlockArrays """ ExplodedConvArray{T, Dims, Storage} <: AbstractArray{T, 4} Represents a an `nxmx1xb` array of images, but rearranged into a series of convolution windows. Evaluating a convolution compatible with `Dims` on this array is achievable through a sequence of scalar multiplications and sums on the underling storage. """ struct ExplodedConvArray{T, Dims, Storage} <: AbstractArray{T, 4} # sx*sy matrix of b*(dx*dy) matrices of extracted elements # where (sx, sy) = kernel_size(Dims) # (dx, dy) = output_size(DenseConvDims(...)) cdims::Dims x::Matrix{Storage} function ExplodedConvArray{T, Dims, Storage}(cdims::Dims, storage::Matrix{Storage}) where {T, Dims, Storage} @assert all(==(size(storage[1])), size.(storage)) new{T, Dims, Storage}(cdims, storage) end end Base.size(ex::ExplodedConvArray) = (NNlib.input_size(ex.cdims)..., 1, size(ex.x[1], 1)) function ExplodedConvArray{T}(cdims, batch::AbstractArray{T, 4}) where {T} x, y = NNlib.output_size(cdims) kx, ky = NNlib.kernel_size(cdims) stridex, stridey = NNlib.stride(cdims) kax = OffsetArray(0:x-1, 0:x-1) kay = OffsetArray(0:x-1, 0:x-1) I = [[batch[i′*stridex .+ (1:kx), j′*stridey .+ (1:ky), 1, k] for i′=kax, j′=kay] for k = 1:size(batch, 4)] Iᵢⱼ = [[I[k][l...][i,j] for k=1:size(batch, 4), l=product(kax, kay)] for (i,j) in product(1:kx, 1:ky)] ExplodedConvArray{T, typeof(cdims), eltype(Iᵢⱼ)}(cdims, Iᵢⱼ) end function NNlib.conv(x::ExplodedConvArray{<:Any, Dims}, weights::AbstractArray{<:Any, 4}, cdims::Dims) where {Dims<:ConvDims} blocks = reshape([ Base.ReshapedArray(sum(xx[i,j]*weights[i,j,1,channel] for i=1:7, j=1:7), (NNlib.output_size(cdims)...,1,size(x, 4)), ()) for channel = 1:4 ],(1,1,4,1)) BlockArrays._BlockArray(blocks, BlockArrays.BlockSizes([8], [8], [1,1,1,1], [64])) end 

Aquí nuevamente utilizamos BlockArrays para representar una matriz de 8x8x4x64 como cuatro matrices de 8x8x1x64 como en el código fuente. Ahora la presentación de la primera etapa se ha vuelto mucho más bella, al menos con arreglos no cifrados:

 julia> cdims = DenseConvDims(batch, model.layers[1].weight; stride=(3,3), padding=(0,0,0,0), dilation=(1,1)) DenseConvDims: (28, 28, 1) * (7, 7) -> (8, 8, 4), stride: (3, 3) pad: (0, 0, 0, 0), dil: (1, 1), flip: false julia> a = ExplodedConvArray{eltype(batch)}(cdims, batch); julia> model(a) 10×64 Array{Float32,2}: [snip] 

Ahora, ¿cómo conectamos esto con el cifrado? Para hacer esto, necesitas:

  1. Cifre la estructura ( ExplodedConvArray ) para que obtengamos el texto cifrado de cada campo. Las operaciones con una estructura encriptada de este tipo verificarán lo que la función haría con la estructura original, y harán lo mismo homomórficamente.
  2. Intercepte ciertas operaciones para realizarlas de manera diferente en un contexto cifrado.

Afortunadamente, Julia nos proporciona una abstracción para esto: un complemento de compilación que utiliza el mecanismo Cassette.jl . No le diré qué es y cómo funciona, diré brevemente que puede determinar el contexto, por ejemplo, Encrypted , y luego define las reglas de cómo deberían funcionar las operaciones en este contexto. Por ejemplo, puede escribir esto para el segundo requisito:

 # Define Matrix multiplication between an array and an encrypted block array function (*::Encrypted{typeof(*)})(a::Array{T, 2}, b::Encrypted{<:BlockArray{T, 2}}) where {T} sum(a*b for (i,range) in enumerate(partition(1:size(a, 2), size(b.blocks[1], 1)))) end # Define Matrix multiplication between an array and an encrypted array function (*::Encrypted{typeof(*)})(a::Array{T, 2}, b::Encrypted{Array{T, 2}}) where {T} result = repeat(diag(a), inner=size(a, 1)).*x rotated = b for k = 2:size(a, 2) rotated = ToyFHE.rotate(GaloisKey(*), rotated) result += repeat(diag(circshift(a, (0,(k-1)))), inner=size(a, 1)) .* rotated end result end 

Como resultado, el usuario podrá escribir todo lo anterior con una cantidad mínima de trabajo manual:

 kp = keygen(ckks_params) ek = keygen(EvalMultKey, kp.priv) gk = keygen(GaloisKey, kp.priv; steps=64) # Create evaluation context ctx = Encrypted(ek, gk) # Do public preprocessing batch = ExplodedConvArray{eltype(batch)}(cdims, batch); # Run on encrypted data under the encryption context Cresult = ctx(model)(encrypt(kp.pub, batch)) # Decrypt the answer decrypt(kp, Cresult) 

, . ( ℛ, modswitch, keyswitch ..) , . , , , , .

Conclusión


— . Julia . RAMPARTS ( paper , JuliaCon talk ) : Julia- - PALISADE. Julia Computing RAMPARTS Verona, . , . . , , .

, ToyFHE . , , , .

Source: https://habr.com/ru/post/478514/


All Articles