"""
Python library for working with encrypted data within nilDB queries and
replies.
"""
from __future__ import annotations
from typing import Union, Optional, Sequence
import doctest
import base64
import secrets
import hashlib
import hmac
from lagrange import lagrange
import bcl
import pailliers
# Module-wide parameters. Each constant is immediately followed by a string
# literal that documents it.
_PAILLIER_KEY_LENGTH = 2048
"""Length in bits of Paillier keys."""
# -(2 ** 31): the smallest value of a 32-bit signed integer.
_PLAINTEXT_SIGNED_INTEGER_MIN = -2147483648
"""Minimum plaintext 32-bit signed integer value that can be encrypted."""
# (2 ** 31) - 1: the largest value of a 32-bit signed integer.
_PLAINTEXT_SIGNED_INTEGER_MAX = 2147483647
"""Maximum plaintext 32-bit signed integer value that can be encrypted."""
# NOTE(review): ``decrypt`` inverts masks with ``pow(x, modulus - 2, modulus)``,
# which presumes this modulus is prime -- confirm.
_SECRET_SHARED_SIGNED_INTEGER_MODULUS = (2 ** 32) + 15
"""Modulus to use for additive secret sharing of 32-bit signed integers."""
# ``encrypt`` compares encoded buffers against this limit plus one, because the
# encoding prepends a single type byte.
_PLAINTEXT_STRING_BUFFER_LEN_MAX = 4096
"""Maximum length of plaintext string values that can be encrypted."""
_HASH = hashlib.sha512
"""Hash function used for HKDF and matching."""
def _hkdf_extract(salt: bytes, input_key: bytes) -> bytes:
    """
    Perform the HKDF "extract" step: return the HMAC digest of the input key
    material, keyed by the salt, as the pseudorandom key (PRK).

    An empty salt is replaced by a run of zero bytes whose length equals the
    digest size of the module's hash function.
    """
    hmac_key = salt if len(salt) != 0 else bytes(_HASH().digest_size)
    return hmac.new(hmac_key, input_key, _HASH).digest()
def _hkdf_expand(pseudo_random_key: bytes, info: bytes, length: int) -> bytes:
    """
    Perform the HKDF "expand" step: stretch the pseudorandom key into output
    key material (OKM) of the requested length.

    Each HMAC block is computed over the previous block, the ``info`` bytes,
    and a one-byte counter that starts at 1; blocks are concatenated until at
    least ``length`` bytes are available, and the result is truncated.
    """
    previous = b''
    blocks = []
    produced = 0
    counter = 0
    while produced < length:
        counter += 1
        # ``bytes([counter])`` only accepts values below 256, which bounds the
        # number of blocks that can be generated.
        previous = hmac.new(
            pseudo_random_key,
            previous + info + bytes([counter]),
            _HASH
        ).digest()
        blocks.append(previous)
        produced += len(previous)
    return b''.join(blocks)[:length]
def _hkdf(length: int, input_key: bytes, salt: bytes = b'', info: bytes = b'') -> bytes:
    """
    Derive ``length`` bytes of key material from ``input_key`` by running the
    extract step (with the optional ``salt``) followed by the expand step
    (with the optional ``info``).
    """
    return _hkdf_expand(_hkdf_extract(salt, input_key), info, length)
def _random_bytes(length: int, seed: Optional[bytes] = None, salt: Optional[bytes] = None) -> bytes:
"""
Return a random :obj:`bytes` value of the specified length (using
the seed if one is supplied).
"""
if seed is not None:
return _hkdf(length, seed, b'' if salt is None else salt)
return secrets.token_bytes(length)
def _random_int(
minimum: int,
maximum: int,
seed: Optional[bytes] = None
) -> int:
"""
Return a random integer value within the specified range (using
the seed if one is supplied) by leveraging rejection sampling.
>>> _random_int(-1, 1)
Traceback (most recent call last):
...
ValueError: minimum must be 0 or 1
>>> _random_int(1, -1)
Traceback (most recent call last):
...
ValueError: maximum must be greater than the minimum and less than the modulus
"""
if minimum < 0 or minimum > 1:
raise ValueError('minimum must be 0 or 1')
if maximum <= minimum or maximum >= _SECRET_SHARED_SIGNED_INTEGER_MODULUS:
raise ValueError(
'maximum must be greater than the minimum and less than the modulus'
)
# Deterministically generate an integer in the specified range
# using the supplied seed. This specific technique is implemented
# explicitly for compatibility with corresponding libraries for
# other languages and platforms.
if seed is not None:
range_ = maximum - minimum
integer = None
index = 0
while integer is None or integer > range_:
bytes_ = bytearray(_random_bytes(8, seed, index.to_bytes(64, 'little')))
index += 1
bytes_[4] &= 1
bytes_[5] &= 0
bytes_[6] &= 0
bytes_[7] &= 0
small = int.from_bytes(bytes_[:4], 'little')
large = int.from_bytes(bytes_[4:], 'little')
integer = small + large * (2 ** 32)
return minimum + integer
return minimum + secrets.randbelow(maximum + 1 - minimum)
def _shamirs_eval(poly, x, prime):
"""
Evaluates polynomial (coefficient tuple) at x.
"""
accum = 0
for coeff in reversed(poly):
accum *= x
accum += coeff
accum %= prime
return accum
def _shamirs_shares(
        secret,
        total_shares,
        minimum_shares,
        prime=_SECRET_SHARED_SIGNED_INTEGER_MODULUS
    ):
    """
    Generate a random polynomial whose constant term is the secret and return
    its evaluations at ``1, ..., total_shares`` as a list of ``[x, y]`` share
    points (with ``y`` reduced modulo ``prime``).

    The polynomial has ``minimum_shares`` coefficients in total. Every
    non-constant coefficient is drawn uniformly from the whole range
    ``0, ..., prime - 1``.

    A :obj:`ValueError` is raised if ``minimum_shares`` exceeds
    ``total_shares``.

    >>> _shamirs_shares(123, 2, 3)
    Traceback (most recent call last):
      ...
    ValueError: total number of shares cannot be less than the minimum number of shares
    """
    if minimum_shares > total_shares:
        raise ValueError(
            'total number of shares cannot be less than the minimum number of shares'
        )

    # ``secrets.randbelow(prime)`` covers every residue modulo ``prime``;
    # using ``prime - 1`` as the bound would never produce ``prime - 1``.
    poly = [secret] + [secrets.randbelow(prime) for _ in range(minimum_shares - 1)]
    return [[i, _shamirs_eval(poly, i, prime)] for i in range(1, total_shares + 1)]
def _shamirs_recover(shares, prime=_SECRET_SHARED_SIGNED_INTEGER_MODULUS):
    """
    Recover the secret value from the supplied share instances.

    The work is delegated entirely to :obj:`lagrange`, which receives the
    shares and the modulus unchanged.

    NOTE(review): presumably each share is an ``[x, y]`` point and the result
    is the interpolated value at ``x = 0`` modulo ``prime`` -- confirm
    against the ``lagrange`` package.

    >>> _shamirs_recover([[0, 123]])
    123
    >>> _shamirs_recover([[0, 123], [1, 123], [2, 123]])
    123
    """
    return lagrange(shares, prime)
def _shamirs_add(shares_a, shares_b, prime=_SECRET_SHARED_SIGNED_INTEGER_MODULUS):
    """
    Add two collections of ``(index, value)`` shares position by position,
    reducing every sum modulo ``prime``. Both collections must have the same
    length and list the same indices in the same order.

    >>> _shamirs_add([(0, 123), (1, 456)], [(0, 123), (1, 456)])
    [[0, 246], [1, 912]]
    >>> _shamirs_add([(0, 123), (1, 456)], [(0, 123)])
    Traceback (most recent call last):
      ...
    ValueError: shares sets must have the same length
    >>> _shamirs_add([(0, 123), (1, 456)], [(0, 123), (2, 456)])
    Traceback (most recent call last):
      ...
    ValueError: shares must have the same indices
    """
    if len(shares_a) != len(shares_b):
        raise ValueError('shares sets must have the same length')

    indices_a = [index for (index, _) in shares_a]
    indices_b = [index for (index, _) in shares_b]
    if indices_a != indices_b:
        raise ValueError('shares must have the same indices')

    combined = []
    for (index_a, value_a), (index_b, value_b) in zip(shares_a, shares_b):
        # Pairs whose indices do not compare equal are left out.
        if index_a == index_b:
            combined.append([index_a, (value_a + value_b) % prime])
    return combined
def _pack(b: bytes) -> str:
"""
Encode a bytes-like object as a Base64 string (for compatibility with JSON).
"""
return base64.b64encode(b).decode('ascii')
def _unpack(s: str) -> bytes:
"""
Decode a bytes-like object from its Base64 string encoding.
"""
return base64.b64decode(s)
def _encode(value: Union[int, str, bytes]) -> bytes:
"""
Encode an integer, string, or binary plaintext as a binary value.
The encoding includes information about the type of the value in
the first byte (to enable decoding without any additional context).
>>> _encode(123).hex()
'007b00008000000000'
>>> _encode('abc').hex()
'01616263'
>>> _encode(bytes([1, 2, 3])).hex()
'02010203'
If a value cannot be encoded, an exception is raised.
>>> _encode([1, 2, 3])
Traceback (most recent call last):
...
ValueError: cannot encode value
"""
if isinstance(value, int):
return (
bytes([0]) +
(value - _PLAINTEXT_SIGNED_INTEGER_MIN).to_bytes(8, 'little')
)
if isinstance(value, str):
return bytes([1]) + value.encode('UTF-8')
if isinstance(value, bytes):
return bytes([2]) + value
raise ValueError('cannot encode value')
def _decode(value: bytes) -> Union[int, str, bytes]:
"""
Decode a binary value back into an integer, string, or binary plaintext.
>>> _decode(_encode(123))
123
>>> _decode(_encode('abc'))
'abc'
>>> _decode(_encode(bytes([1, 2, 3])))
b'\\x01\\x02\\x03'
If a value cannot be decoded, an exception is raised.
>>> _decode([1, 2, 3])
Traceback (most recent call last):
...
TypeError: can only decode from a bytes value
>>> _decode(bytes([3]))
Traceback (most recent call last):
...
ValueError: cannot decode value
"""
if not isinstance(value, bytes):
raise TypeError('can only decode from a bytes value')
if value[0] == 0: # Indicates encoded value is a 32-bit signed integer.
integer = int.from_bytes(value[1:], 'little')
return integer + _PLAINTEXT_SIGNED_INTEGER_MIN
if value[0] == 1: # Indicates encoded value is a UTF-8 string.
return value[1:].decode('UTF-8')
if value[0] == 2: # Indicates encoded value is binary data.
return value[1:]
raise ValueError('cannot decode value')
class SecretKey(dict):
    """
    Data structure for representing all categories of secret key instances.

    An instance is a dictionary with the entries ``'material'`` (the
    cryptographic material), ``'cluster'`` (the cluster configuration),
    ``'operations'`` (the operations specification), and -- only when one was
    supplied -- ``'threshold'``.
    """
    _paillier_key_length = _PAILLIER_KEY_LENGTH
    """
    Static parameter for Paillier cryptosystem (introduced in order to allow
    modification in tests).
    """

    @staticmethod
    def generate(
            cluster: dict = None,
            operations: dict = None,
            threshold: Optional[int] = None,
            seed: Union[bytes, bytearray, str] = None
        ) -> SecretKey:
        """
        Return a secret key built according to what is specified in the supplied
        cluster configuration, operation specification, and other parameters.

        The cluster configuration must be a dictionary with a ``'nodes'``
        sequence containing at least one entry. The operations specification
        must be a dictionary whose keys are drawn from ``'store'``,
        ``'match'``, and ``'sum'``, with exactly one of them enabled. A string
        seed is encoded to bytes before use.

        >>> sk = SecretKey.generate({'nodes': [{}]}, {'sum': True})
        >>> isinstance(sk, SecretKey)
        True

        Supplying an invalid combination of configurations and/or parameters raises
        a corresponding exception.

        >>> SecretKey.generate({'nodes': [{}]}, {'sum': True}, threshold='abc')
        Traceback (most recent call last):
          ...
        TypeError: threshold must be an integer
        >>> SecretKey.generate({'nodes': [{}, {}]}, {'match': True}, threshold=1)
        Traceback (most recent call last):
          ...
        ValueError: thresholds are only supported for the sum operation
        >>> SecretKey.generate({'nodes': [{}]}, {'sum': True}, threshold=1)
        Traceback (most recent call last):
          ...
        ValueError: thresholds are only supported for multiple-node clusters
        >>> SecretKey.generate({'nodes': [{}]}, {'sum': True}, threshold=-1)
        Traceback (most recent call last):
          ...
        ValueError: threshold must a positive integer not larger than the cluster size
        >>> SecretKey.generate({'nodes': [{}, {}]}, {'sum': True}, threshold=3)
        Traceback (most recent call last):
          ...
        ValueError: threshold must a positive integer not larger than the cluster size
        >>> SecretKey.generate({'nodes': [{}]}, {'sum': True}, seed=bytes([123]))
        Traceback (most recent call last):
          ...
        ValueError: seed-based derivation of summation-compatible keys is not supported \
for single-node clusters
        """
        # Normalize type of seed argument.
        if isinstance(seed, str):
            seed = seed.encode()

        # Create instance with the supplied cluster configuration and
        # operations specification; the material is filled in further below.
        secret_key = SecretKey({
            'material': {},
            'cluster': cluster,
            'operations': operations
        })
        if threshold is not None:
            secret_key['threshold'] = threshold

        if (
            not isinstance(cluster, dict) or
            'nodes' not in cluster or
            not isinstance(cluster['nodes'], Sequence)
        ):
            raise ValueError('valid cluster configuration is required')

        cluster_size = len(cluster['nodes'])
        if cluster_size < 1:
            raise ValueError('cluster configuration must contain at least one node')

        if (
            (not isinstance(operations, dict)) or
            (not set(operations.keys()).issubset({'store', 'match', 'sum'}))
        ):
            raise ValueError('valid operations specification is required')

        if len([op for (op, status) in operations.items() if status]) != 1:
            raise ValueError('secret key must support exactly one operation')

        if threshold is not None:
            if not isinstance(threshold, int):
                raise TypeError('threshold must be an integer')

            if threshold < 1 or threshold > cluster_size:
                raise ValueError(
                    'threshold must a positive integer not larger than the cluster size'
                )

            if cluster_size == 1:
                raise ValueError(
                    'thresholds are only supported for multiple-node clusters'
                )

            if not operations.get('sum'):
                raise ValueError(
                    'thresholds are only supported for the sum operation'
                )

        if operations.get('store'):
            # Symmetric key for encrypting the plaintext or the shares of a plaintext.
            secret_key['material'] = (
                bcl.symmetric.secret()
                if seed is None else
                bytes.__new__(bcl.secret, _random_bytes(32, seed))
            )

        if operations.get('match'):
            # Salt for deterministic hashing of the plaintext.
            secret_key['material'] = _random_bytes(64, seed)

        if operations.get('sum'):
            if cluster_size == 1:
                # Paillier secret key for encrypting a plaintext integer value.
                if seed is not None:
                    raise ValueError(
                        'seed-based derivation of summation-compatible keys ' +
                        'is not supported for single-node clusters'
                    )
                secret_key['material'] = pailliers.secret(SecretKey._paillier_key_length)
            else:
                # Distinct multiplicative mask for each additive share. When a
                # seed is supplied, each mask is derived from its own sub-seed
                # (the seed salted with the index of the node).
                secret_key['material'] = [
                    _random_int(
                        1,
                        _SECRET_SHARED_SIGNED_INTEGER_MODULUS - 1,
                        (
                            _random_bytes(64, seed, i.to_bytes(64, 'little'))
                            if seed is not None else
                            None
                        )
                    )
                    for i in range(cluster_size)
                ]

        return secret_key

    def dump(self: SecretKey) -> dict:
        """
        Return a JSON-compatible dictionary representation of this key
        instance.

        A list of integer masks is emitted as is, binary material is emitted
        as a Base64 string, and any other material is treated as a Paillier
        secret key whose four components are emitted as decimal strings under
        the keys ``'l'``, ``'m'``, ``'n'``, and ``'g'``.

        >>> import json
        >>> sk = SecretKey.generate({'nodes': [{}]}, {'store': True})
        >>> isinstance(json.dumps(sk.dump()), str)
        True
        """
        dictionary = {
            'material': {},
            'cluster': self['cluster'],
            'operations': self['operations'],
        }
        if 'threshold' in self:
            dictionary['threshold'] = self['threshold']

        material = self['material']
        if isinstance(material, list):
            # Additive secret sharing node-specific masks.
            if all(isinstance(k, int) for k in material):
                dictionary['material'] = material
        elif isinstance(material, (bytes, bytearray)):
            dictionary['material'] = _pack(material)
        else:
            # Secret key for Paillier encryption.
            dictionary['material'] = {
                'l': str(material[0]),
                'm': str(material[1]),
                'n': str(material[2]),
                'g': str(material[3])
            }

        return dictionary

    @staticmethod
    def load(dictionary: dict) -> SecretKey:
        """
        Return an instance built from a JSON-compatible dictionary
        representation.

        Binary material is given the symmetric key type only when the
        ``'store'`` operation is enabled (not merely listed) in the operations
        specification, matching how :obj:`generate` consults that
        specification.

        >>> sk = SecretKey.generate({'nodes': [{}]}, {'store': True})
        >>> sk == SecretKey.load(sk.dump())
        True
        """
        secret_key = SecretKey({
            'material': {},
            'cluster': dictionary['cluster'],
            'operations': dictionary['operations'],
        })
        if 'threshold' in dictionary:
            secret_key['threshold'] = dictionary['threshold']

        material = dictionary['material']
        if isinstance(material, list):
            # Additive secret sharing node-specific masks.
            if all(isinstance(k, int) for k in material):
                secret_key['material'] = material
        elif isinstance(material, str):
            secret_key['material'] = _unpack(material)
            # If this is a secret symmetric key, ensure it has the
            # expected type. A disabled ``'store'`` entry (for example
            # ``{'store': False, 'match': True}``) must not trigger this.
            if secret_key['operations'].get('store'):
                secret_key['material'] = bytes.__new__(
                    bcl.secret,
                    secret_key['material']
                )
        else:
            # Secret key for Paillier encryption.
            secret_key['material'] = tuple.__new__(
                pailliers.secret,
                (
                    int(material['l']),
                    int(material['m']),
                    int(material['n']),
                    int(material['g'])
                )
            )

        return secret_key
class ClusterKey(SecretKey):
    """
    Data structure for representing all categories of cluster key instances.

    A cluster key records a cluster configuration, an operations
    specification, and optionally a threshold, but carries no cryptographic
    material.
    """

    @staticmethod
    def generate( # pylint: disable=arguments-differ # Seeds not supported.
            cluster: dict = None,
            operations: dict = None,
            threshold: Optional[int] = None
        ) -> ClusterKey:
        """
        Build a cluster key from the supplied cluster configuration and
        operation specification (and, optionally, a threshold).

        >>> ck = ClusterKey.generate({'nodes': [{}, {}, {}]}, {'sum': True})
        >>> isinstance(ck, ClusterKey)
        True

        Cluster keys can only be created for clusters that have two or more nodes.

        >>> ClusterKey.generate({'nodes': [{}]}, {'store': True})
        Traceback (most recent call last):
          ...
        ValueError: cluster configuration must have at least two nodes
        """
        # Reuse the secret key generator (and all of its validation), then
        # wrap the outcome in a cluster key instance.
        generated = SecretKey.generate(cluster, operations, threshold)
        instance = ClusterKey(generated)

        nodes = instance['cluster']['nodes']
        if len(nodes) == 1:
            raise ValueError('cluster configuration must have at least two nodes')

        # Cluster keys contain no cryptographic material.
        instance.pop('material', None)

        return instance

    def dump(self: ClusterKey) -> dict:
        """
        Produce a JSON-compatible dictionary holding the cluster
        configuration, the operations specification, and (if present) the
        threshold of this key instance.

        >>> import json
        >>> cluster = {'nodes': [{}, {}, {}]}
        >>> ck = ClusterKey.generate(cluster, {'sum': True}, threshold=2)
        >>> isinstance(json.dumps(ck.dump()), str)
        True
        """
        exported = {name: self[name] for name in ('cluster', 'operations')}
        if 'threshold' in self:
            exported['threshold'] = self['threshold']

        return exported

    @staticmethod
    def load(dictionary: dict) -> ClusterKey:
        """
        Rebuild a cluster key instance from its JSON-compatible dictionary
        representation.

        >>> cluster = {'nodes': [{}, {}, {}]}
        >>> ck = ClusterKey.generate(cluster, {'sum': True}, threshold=2)
        >>> ck == ClusterKey.load(ck.dump())
        True
        """
        entries = {
            'cluster': dictionary['cluster'],
            'operations': dictionary['operations'],
        }
        instance = ClusterKey(entries)
        if 'threshold' in dictionary:
            instance['threshold'] = dictionary['threshold']

        return instance
class PublicKey(dict):
    """
    Data structure for representing all categories of public key instances.

    An instance is a dictionary with the entries ``'cluster'``,
    ``'operations'``, and ``'material'`` (a Paillier public key).
    """

    @staticmethod
    def generate(secret_key: SecretKey) -> PublicKey:
        """
        Return a public key built according to what is specified in the supplied
        secret key.

        Only secret keys whose material is a Paillier secret key have a
        corresponding public key; for any other key (including keys that have
        no material at all) a :obj:`ValueError` is raised.

        >>> sk = SecretKey.generate({'nodes': [{}]}, {'sum': True})
        >>> isinstance(PublicKey.generate(sk), PublicKey)
        True
        """
        # Create instance with the cluster configuration and operations
        # specification of the supplied secret key.
        public_key = PublicKey({
            'cluster': secret_key['cluster'],
            'operations': secret_key['operations']
        })

        # Keys without any material (such as cluster keys) are looked up with
        # ``get`` so that they are rejected with the error below rather than
        # with a ``KeyError``.
        material = secret_key.get('material')
        if isinstance(material, pailliers.secret):
            public_key['material'] = pailliers.public(material)
        else:
            raise ValueError('cannot create public key for supplied secret key')

        return public_key

    def dump(self: PublicKey) -> dict:
        """
        Return a JSON-compatible dictionary representation of this key
        instance. The two components of the Paillier public key are emitted
        as decimal strings under the keys ``'n'`` and ``'g'``.

        >>> import json
        >>> sk = SecretKey.generate({'nodes': [{}]}, {'sum': True})
        >>> pk = PublicKey.generate(sk)
        >>> isinstance(json.dumps(pk.dump()), str)
        True
        """
        return {
            # Public key for Paillier encryption.
            'material': {
                'n': str(self['material'][0]),
                'g': str(self['material'][1])
            },
            'cluster': self['cluster'],
            'operations': self['operations'],
        }

    @staticmethod
    def load(dictionary: dict) -> PublicKey:
        """
        Return an instance built from a JSON-compatible dictionary
        representation.

        >>> sk = SecretKey.generate({'nodes': [{}]}, {'sum': True})
        >>> pk = PublicKey.generate(sk)
        >>> pk == PublicKey.load(pk.dump())
        True
        """
        public_key = PublicKey({
            'cluster': dictionary['cluster'],
            'operations': dictionary['operations'],
        })

        # Public key for Paillier encryption.
        public_key['material'] = tuple.__new__(
            pailliers.public,
            (
                int(dictionary['material']['n']),
                int(dictionary['material']['g'])
            )
        )

        return public_key
def encrypt(
        key: Union[SecretKey, PublicKey],
        plaintext: Union[int, str, bytes]
    ) -> Union[str, Sequence[str], Sequence[int], Sequence[Sequence[int]]]:
    """
    Return the ciphertext obtained by using the supplied key to encrypt the
    supplied plaintext.

    The form of the ciphertext depends on the operation enabled in the key
    and on the number of nodes in its cluster:

    * ``'store'``: a Base64 string for a single node, or a list of Base64
      strings (XOR shares of the encoded plaintext) for multiple nodes;
    * ``'match'``: the Base64 encoding of a salted hash of the encoded
      plaintext, replicated into a list for multiple nodes;
    * ``'sum'``: a hexadecimal Paillier ciphertext for a single node, a list
      of additive shares for multiple nodes without a threshold, or a list
      of ``[index, value]`` Shamir shares for multiple nodes with a
      threshold.

    Integer plaintexts must lie between the minimum and maximum 32-bit signed
    integer values, both inclusive.

    >>> key = SecretKey.generate({'nodes': [{}]}, {'store': True})
    >>> isinstance(encrypt(key, 123), str)
    True
    >>> isinstance(encrypt(key, 2147483647), str)
    True

    Invocations that involve invalid argument values or types may raise an
    exception.

    >>> key = SecretKey.generate({'nodes': [{}]}, {'sum': True})
    >>> encrypt(key, [])
    Traceback (most recent call last):
      ...
    TypeError: plaintext to encrypt for sum operation must be an integer
    >>> encrypt(key, 2 ** 64)
    Traceback (most recent call last):
      ...
    ValueError: numeric plaintext must be a valid 32-bit signed integer
    >>> del key['operations']['sum']
    >>> encrypt(key, 123)
    Traceback (most recent call last):
      ...
    ValueError: cannot encrypt the supplied plaintext using the supplied key
    """
    buffer = None

    # Encode string or binary data for storage or matching.
    if isinstance(plaintext, (str, bytes)):
        buffer = _encode(plaintext)
        # The encoding carries one extra byte that identifies the type.
        if len(buffer) > _PLAINTEXT_STRING_BUFFER_LEN_MAX + 1:
            raise ValueError(
                'string or binary plaintext must be possible to encode in ' +
                str(_PLAINTEXT_STRING_BUFFER_LEN_MAX) +
                ' bytes or fewer'
            )

    # Encode integer data for storage or matching.
    if isinstance(plaintext, int):
        # Only 32-bit signed integer plaintexts are supported; the maximum
        # value itself is a valid plaintext.
        if (
            plaintext < _PLAINTEXT_SIGNED_INTEGER_MIN or
            plaintext > _PLAINTEXT_SIGNED_INTEGER_MAX
        ):
            raise ValueError('numeric plaintext must be a valid 32-bit signed integer')

        # Encode an integer for storage or matching.
        buffer = _encode(plaintext)

    # Encrypt a plaintext for storage and retrieval.
    if key['operations'].get('store'):
        # For single-node clusters, the data is encrypted using a symmetric key.
        if len(key['cluster']['nodes']) == 1:
            return _pack(
                bcl.symmetric.encrypt(key['material'], bcl.plain(buffer))
            )

        # For multiple-node clusters, the ciphertext is secret-shared using XOR
        # (with each share symmetrically encrypted in the case of a secret key).
        optional_enc = (
            (lambda s: bcl.symmetric.encrypt(key['material'], bcl.plain(s)))
            if 'material' in key else
            (lambda s: s)
        )

        # All shares but the last are random masks; the last share is the XOR
        # of the encoded plaintext with every mask.
        shares = []
        aggregate = bytes(len(buffer))
        for _ in range(len(key['cluster']['nodes']) - 1):
            mask = _random_bytes(len(buffer))
            aggregate = bytes(a ^ b for (a, b) in zip(aggregate, mask))
            shares.append(optional_enc(mask))
        shares.append(optional_enc(
            bytes(a ^ b for (a, b) in zip(aggregate, buffer))
        ))
        return list(map(_pack, shares))

    # Encrypt (i.e., hash) a plaintext for matching.
    if key['operations'].get('match'):
        # The deterministic salted hash of the encoded plaintext is the ciphertext.
        ciphertext = _pack(_HASH(key['material'] + buffer).digest())

        # For multiple-node clusters, replicate the ciphertext for each node.
        if len(key['cluster']['nodes']) > 1:
            ciphertext = [ciphertext for _ in key['cluster']['nodes']]

        return ciphertext

    # Encrypt an integer plaintext in a summation-compatible way.
    if key['operations'].get('sum'):
        # Non-integer cannot be encrypted for summation.
        if not isinstance(plaintext, int):
            raise TypeError('plaintext to encrypt for sum operation must be an integer')

        quantity = len(key['cluster']['nodes'])

        # For single-node clusters, the Paillier cryptosystem is used.
        if quantity == 1:
            return hex(pailliers.encrypt(key['material'], plaintext))[2:] # No '0x'.

        # Multiplicative mask for the share of each node (or the neutral
        # mask if the key has no material).
        masks = [
            key['material'][i] if 'material' in key else 1
            for i in range(quantity)
        ]

        # For multiple-node clusters and no threshold, additive secret sharing is used.
        if 'threshold' not in key:
            shares = []
            total = 0
            for i in range(quantity - 1):
                share_ = _random_int(0, _SECRET_SHARED_SIGNED_INTEGER_MODULUS - 1)
                shares.append(
                    (masks[i] * share_)
                    %
                    _SECRET_SHARED_SIGNED_INTEGER_MODULUS
                )
                total = (total + share_) % _SECRET_SHARED_SIGNED_INTEGER_MODULUS

            # The last share makes the unmasked shares sum to the plaintext.
            shares.append(
                (
                    masks[quantity - 1] *
                    ((plaintext - total) % _SECRET_SHARED_SIGNED_INTEGER_MODULUS)
                ) % _SECRET_SHARED_SIGNED_INTEGER_MODULUS
            )
            return shares

        # For multiple-node clusters and a threshold, Shamir's secret sharing is used.
        shares = _shamirs_shares(plaintext, quantity, key['threshold'])
        for (i, share) in enumerate(shares):
            share[1] = (masks[i] * share[1]) % _SECRET_SHARED_SIGNED_INTEGER_MODULUS
        return shares

    # The below should not occur unless the key's cluster or operations
    # information is malformed/missing or the plaintext is unsupported.
    raise ValueError('cannot encrypt the supplied plaintext using the supplied key')
def decrypt(
        key: SecretKey,
        ciphertext: Union[str, Sequence[str], Sequence[int], Sequence[Sequence[int]]]
    ) -> Union[int, str, bytes]:
    """
    Return the plaintext obtained by using the supplied key to decrypt the
    supplied ciphertext. The supplied ciphertext is never modified.

    >>> key = SecretKey.generate({'nodes': [{}, {}]}, {'store': True})
    >>> decrypt(key, encrypt(key, 123))
    123
    >>> key = SecretKey.generate({'nodes': [{}, {}]}, {'store': True})
    >>> decrypt(key, encrypt(key, -10))
    -10
    >>> key = SecretKey.generate({'nodes': [{}, {}]}, {'store': True})
    >>> decrypt(key, encrypt(key, bytes([1, 2, 3])))
    b'\\x01\\x02\\x03'
    >>> key = SecretKey.generate({'nodes': [{}]}, {'store': True})
    >>> decrypt(key, encrypt(key, 'abc'))
    'abc'
    >>> key = SecretKey.generate({'nodes': [{}]}, {'store': True})
    >>> decrypt(key, encrypt(key, 123))
    123
    >>> key = SecretKey.generate({'nodes': [{}, {}]}, {'sum': True})
    >>> decrypt(key, encrypt(key, 123))
    123
    >>> key = SecretKey.generate({'nodes': [{}, {}]}, {'sum': True})
    >>> decrypt(key, encrypt(key, -10))
    -10
    >>> key = SecretKey.generate({'nodes': [{}, {}]}, {'sum': True}, threshold=2)
    >>> decrypt(key, encrypt(key, 123))
    123
    >>> key = SecretKey.generate({'nodes': [{}, {}, {}, {}]}, {'sum': True}, threshold=3)
    >>> decrypt(key, encrypt(key, 123)[:-1])
    123
    >>> key = SecretKey.generate({'nodes': [{}, {}, {}, {}]}, {'sum': True}, threshold=2)
    >>> decrypt(key, encrypt(key, 123)[2:])
    123
    >>> key = SecretKey.generate({'nodes': [{}, {}]}, {'sum': True}, threshold=1)
    >>> decrypt(key, encrypt(key, 123)[1:])
    123
    >>> key = SecretKey.generate({'nodes': [{}, {}]}, {'sum': True}, threshold=2)
    >>> decrypt(key, encrypt(key, -10))
    -10

    Decrypting the same ciphertext more than once yields the same plaintext.

    >>> key = SecretKey.generate({'nodes': [{}, {}]}, {'sum': True}, threshold=2)
    >>> ciphertext = encrypt(key, 123)
    >>> (decrypt(key, ciphertext), decrypt(key, ciphertext))
    (123, 123)

    An exception is raised if a ciphertext cannot be decrypted using the
    supplied key (*e.g.*, because one or both are malformed or they are
    incompatible).

    >>> key = SecretKey.generate({'nodes': [{}, {}]}, {'store': True})
    >>> decrypt(key, 'abc')
    Traceback (most recent call last):
      ...
    ValueError: secret key requires a valid ciphertext from a multiple-node cluster
    >>> decrypt(
    ...     SecretKey({'cluster': {'nodes': [{}]}, 'operations': {}}),
    ...     'abc'
    ... )
    Traceback (most recent call last):
      ...
    ValueError: cannot decrypt the supplied ciphertext using the supplied key
    >>> key_alt = SecretKey.generate({'nodes': [{}, {}]}, {'store': True})
    >>> decrypt(key_alt, encrypt(key, 123))
    Traceback (most recent call last):
      ...
    ValueError: cannot decrypt the supplied ciphertext using the supplied key
    """
    error = ValueError(
        'cannot decrypt the supplied ciphertext using the supplied key'
    )
    cluster_size = len(key['cluster']['nodes'])

    # Confirm that the secret key and ciphertext have compatible cluster
    # specifications.
    if cluster_size == 1:
        if not isinstance(ciphertext, str):
            raise ValueError(
                'secret key requires a valid ciphertext from a single-node cluster'
            )
    else:
        # The ciphertext must be a sequence made up exclusively of pairs of
        # integers, exclusively of integers, or exclusively of strings.
        if (
            isinstance(ciphertext, str) or # Must be a container sequence.
            (not isinstance(ciphertext, Sequence)) or
            (not (
                all(
                    (
                        isinstance(c, Sequence) and
                        len(c) == 2 and
                        all(isinstance(x, int) for x in c)
                    )
                    for c in ciphertext
                ) or
                all(isinstance(c, int) for c in ciphertext) or
                all(isinstance(c, str) for c in ciphertext)
            ))
        ):
            raise ValueError(
                'secret key requires a valid ciphertext from a multiple-node cluster'
            )

    if (
        isinstance(ciphertext, Sequence) and
        len(ciphertext) < (
            key['threshold']
            if 'threshold' in key else
            cluster_size
        )
    ):
        raise ValueError(
            'ciphertext must have enough shares for cluster size or threshold'
        )

    # Decrypt a value that was encrypted for storage and retrieval.
    if key['operations'].get('store'):
        # For single-node clusters, the plaintext is encrypted using a symmetric key.
        if cluster_size == 1:
            try:
                return _decode(
                    bcl.symmetric.decrypt(
                        key['material'],
                        bcl.cipher(_unpack(ciphertext))
                    )
                )
            except Exception as exc:
                raise error from exc

        # For multiple-node clusters, the ciphertext is secret-shared using XOR
        # (with each share symmetrically encrypted in the case of a secret key).
        shares = [_unpack(share) for share in ciphertext]
        if 'material' in key:
            try:
                shares = [
                    bcl.symmetric.decrypt(key['material'], bcl.cipher(share))
                    for share in shares
                ]
            except Exception as exc:
                raise error from exc

        bytes_ = bytes(len(shares[0]))
        for share_ in shares:
            bytes_ = bytes(a ^ b for (a, b) in zip(bytes_, share_))

        return _decode(bytes_)

    # Decrypt a value that was encrypted in a summation-compatible way.
    if key['operations'].get('sum'):
        # For single-node clusters, the Paillier cryptosystem is used.
        if cluster_size == 1:
            return pailliers.decrypt(
                key['material'],
                pailliers.cipher(int(ciphertext, 16))
            )

        # Multiplicative inverse (modulo the secret sharing modulus) of the
        # mask of each node, obtained by raising the mask to the power of the
        # modulus minus two.
        inverse_masks = [
            pow(
                key['material'][i] if 'material' in key else 1,
                _SECRET_SHARED_SIGNED_INTEGER_MODULUS - 2,
                _SECRET_SHARED_SIGNED_INTEGER_MODULUS
            )
            for i in range(cluster_size)
        ]

        # For multiple-node clusters and no threshold, additive secret sharing is used.
        if 'threshold' not in key:
            plaintext = 0
            for (i, share_) in enumerate(ciphertext):
                plaintext = (
                    plaintext +
                    ((inverse_masks[i] * share_) % _SECRET_SHARED_SIGNED_INTEGER_MODULUS)
                ) % _SECRET_SHARED_SIGNED_INTEGER_MODULUS

            # Field elements in the "upper half" of the field represent negative
            # integers.
            if plaintext > _PLAINTEXT_SIGNED_INTEGER_MAX:
                plaintext -= _SECRET_SHARED_SIGNED_INTEGER_MODULUS

            return plaintext

        # For multiple-node clusters and a threshold, Shamir's secret sharing is used.
        # The unmasked shares are collected in a new list so that the shares
        # supplied by the caller are left untouched; a share with index ``k``
        # is unmasked with the inverse mask at position ``k - 1``.
        shares = [
            [
                share[0],
                (
                    inverse_masks[share[0] - 1] * share[1]
                ) % _SECRET_SHARED_SIGNED_INTEGER_MODULUS
            ]
            for share in ciphertext
        ]
        plaintext = _shamirs_recover(shares)

        # Field elements in the "upper half" of the field represent negative
        # integers.
        if plaintext > _PLAINTEXT_SIGNED_INTEGER_MAX:
            plaintext -= _SECRET_SHARED_SIGNED_INTEGER_MODULUS

        return plaintext

    raise error
def allot(
        document: Union[int, bool, str, list, dict]
    ) -> Sequence[Union[int, bool, str, list, dict]]:
    """
    Split a document that may contain ciphertexts intended for multiple-node
    clusters (wrapped as ``{'%allot': ...}``) into one document per share, in
    which each wrapper is replaced by ``{'%share': ...}`` holding the
    corresponding share. Shallow copies are created whenever possible.

    >>> d = {
    ...     'id': 0,
    ...     'age': {'%allot': [1, 2, 3]},
    ...     'dat': {'loc': {'%allot': [4, 5, 6]}}
    ... }
    >>> for d in allot(d): print(d)
    {'id': 0, 'age': {'%share': 1}, 'dat': {'loc': {'%share': 4}}}
    {'id': 0, 'age': {'%share': 2}, 'dat': {'loc': {'%share': 5}}}
    {'id': 0, 'age': {'%share': 3}, 'dat': {'loc': {'%share': 6}}}

    A document with no ciphertexts intended for decentralized clusters is
    unmodified; a list containing this document is returned.

    >>> allot({'id': 0, 'age': 23})
    [{'id': 0, 'age': 23}]

    Any attempt to convert a document that has an incorrect structure raises
    an exception.

    >>> allot({1, 2, 3})
    Traceback (most recent call last):
      ...
    TypeError: boolean, integer, float, string, list, dictionary, or None expected
    >>> allot({'id': 0, 'age': {'%allot': [1, 2, 3], 'extra': [1, 2, 3]}})
    Traceback (most recent call last):
      ...
    ValueError: allotment must only have one key
    >>> allot({
    ...     'id': 0,
    ...     'age': {'%allot': [1, 2, 3]},
    ...     'dat': {'loc': {'%allot': [4, 5]}}
    ... })
    Traceback (most recent call last):
      ...
    ValueError: number of shares in subdocument is not consistent
    >>> allot([
    ...     0,
    ...     {'%allot': [1, 2, 3]},
    ...     {'loc': {'%allot': [4, 5]}}
    ... ])
    Traceback (most recent call last):
      ...
    ValueError: number of shares in subdocument is not consistent
    """
    # Plain values and ``None`` need no splitting; they form a single share.
    if document is None or isinstance(document, (bool, int, float, str)):
        return [document]

    if isinstance(document, list):
        parts = [allot(element) for element in document]

        # Work out how many shares are needed: every element that splits
        # into something other than one share must agree on the count.
        count = 1
        for part in parts:
            if len(part) == 1:
                continue
            if count == 1:
                count = len(part)
            elif count != len(part):
                raise ValueError(
                    'number of shares in subdocument is not consistent'
                )

        # Elements with a single share are repeated in every output share.
        return [
            [part[0] if len(part) == 1 else part[position] for part in parts]
            for position in range(count)
        ]

    if isinstance(document, dict):
        # Shares obtained from the ``encrypt`` function that are to be
        # distributed among the nodes.
        if '%allot' in document:
            if len(document.keys()) != 1:
                raise ValueError('allotment must only have one key')

            items = document['%allot']
            if isinstance(items, list):
                # Flat list of shares: one wrapper per share.
                if (
                    all(isinstance(item, int) for item in items) or
                    all(isinstance(item, str) for item in items)
                ):
                    return [{'%share': item} for item in items]

                # Nested lists of shares: allot each entry separately and
                # gather the unwrapped results under a single wrapper.
                nested = allot([{'%allot': item} for item in items])
                return [
                    {'%share': [entry['%share'] for entry in group]}
                    for group in nested
                ]

        # Any other dictionary is an ordinary key-value mapping.
        parts = {}
        count = 1
        for (name, value) in document.items():
            part = allot(value)
            parts[name] = part
            if len(part) == 1:
                continue
            if count == 1:
                count = len(part)
            elif count != len(part):
                raise ValueError(
                    'number of shares in subdocument is not consistent'
                )

        # Entries with a single share are repeated in every output share.
        return [
            {
                name: (part[0] if len(part) == 1 else part[position])
                for (name, part) in parts.items()
            }
            for position in range(count)
        ]

    raise TypeError(
        'boolean, integer, float, string, list, dictionary, or None expected'
    )
def unify(
    secret_key: SecretKey,
    documents: Sequence[Union[int, bool, str, list, dict]],
    ignore: Optional[Sequence[str]] = None
) -> Union[int, bool, str, list, dict]:
    """
    Reassemble a single document from a sequence of document shares (such as
    those produced by :obj:`allot`). Every ``{'%share': ...}`` entry found at
    the same position in all of the shares is decrypted using the supplied
    secret key; every other value must be identical across all of the shares
    and is carried over unchanged.

    :param secret_key: Secret key passed to :obj:`decrypt` for each group of
        ``'%share'`` values.
    :param documents: Document shares to combine (one per node).
    :param ignore: Keys of key-value mappings that are excluded from
        unification and omitted from the result (``['_created', '_updated']``
        when ``None``).
    :return: The unified document.
    :raises ValueError: If ``documents`` is empty.
    :raises TypeError: If the document shares are not compatible with one
        another.

    >>> data = {
    ...     'a': [True, 'v', 12],
    ...     'b': [False, 'w', 34],
    ...     'c': [True, 'x', 56],
    ...     'd': [False, 'y', 78],
    ...     'e': [True, 'z', 90],
    ... }
    >>> sk = SecretKey.generate({'nodes': [{}, {}, {}]}, {'store': True})
    >>> encrypted = {
    ...     'a': [True, 'v', {'%allot': encrypt(sk, 12)}],
    ...     'b': [False, 'w', {'%allot': encrypt(sk, 34)}],
    ...     'c': [True, 'x', {'%allot': encrypt(sk, 56)}],
    ...     'd': [False, 'y', {'%allot': encrypt(sk, 78)}],
    ...     'e': [True, 'z', {'%allot': encrypt(sk, 90)}],
    ... }
    >>> shares = allot(encrypted)
    >>> decrypted = unify(sk, shares)
    >>> data == decrypted
    True

    It is possible to wrap nested lists of shares to reduce the overhead
    associated with the ``{'%allot': ...}`` and ``{'%share': ...}`` wrappers.

    >>> data = {
    ...     'a': [1, [2, 3]],
    ...     'b': [4, 5, 6],
    ...     'c': None,
    ...     'd': 1.23
    ... }
    >>> sk = SecretKey.generate({'nodes': [{}, {}, {}]}, {'store': True})
    >>> encrypted = {
    ...     'a': {'%allot': [encrypt(sk, 1), [encrypt(sk, 2), encrypt(sk, 3)]]},
    ...     'b': {'%allot': [encrypt(sk, 4), encrypt(sk, 5), encrypt(sk, 6)]},
    ...     'c': None,
    ...     'd': 1.23
    ... }
    >>> shares = allot(encrypted)
    >>> decrypted = unify(sk, shares)
    >>> data == decrypted
    True

    The ``ignore`` parameter specifies which keys should be ignored during
    unification. By default, ``'_created'`` and ``'_updated'`` are ignored.

    >>> shares[0]['_created'] = '123'
    >>> shares[1]['_created'] = '456'
    >>> shares[2]['_created'] = '789'
    >>> shares[0]['_updated'] = 'ABC'
    >>> shares[1]['_updated'] = 'DEF'
    >>> shares[2]['_updated'] = 'GHI'
    >>> decrypted = unify(sk, shares)
    >>> data == decrypted
    True

    Unification returns the sole document when a one-document list is supplied
    (the document is returned as-is, so ignored keys are not removed from it).

    >>> 123 == unify(sk, [123])
    True

    Any attempt to supply incompatible document shares raises an exception.

    >>> unify(sk, [123, 'abc'])
    Traceback (most recent call last):
      ...
    TypeError: array of compatible document shares expected

    Nested lists of shares must have the same length in every document share.

    >>> unify(sk, [{'%share': [1, 2]}, {'%share': [3]}])
    Traceback (most recent call last):
      ...
    TypeError: array of compatible document shares expected

    At least one document share must be supplied.

    >>> unify(sk, [])
    Traceback (most recent call last):
      ...
    ValueError: at least one document share expected
    """
    if ignore is None:
        ignore = ['_created', '_updated']

    # Without this guard the vacuously true ``all`` checks below would index
    # into an empty sequence and fail with an unhelpful ``IndexError``.
    if not documents:
        raise ValueError('at least one document share expected')

    if len(documents) == 1:
        return documents[0]

    if all(isinstance(document, list) for document in documents):
        length = len(documents[0])
        if all(len(document) == length for document in documents[1:]):
            return [
                unify(secret_key, [share[i] for share in documents], ignore)
                for i in range(length)
            ]

    if all(isinstance(document, dict) for document in documents):
        # Documents are shares.
        if all('%share' in document for document in documents):
            values = [document['%share'] for document in documents]

            # Simple document shares.
            if (
                all(isinstance(value, int) for value in values) or
                all(isinstance(value, str) for value in values)
            ):
                return decrypt(secret_key, values)

            # Document shares consisting of nested lists of shares. The
            # lengths are compared explicitly because ``zip`` would otherwise
            # silently drop the trailing entries of the longer lists.
            if (
                all(isinstance(value, (list, tuple)) for value in values) and
                all(len(value) == len(values[0]) for value in values[1:])
            ):
                return [
                    unify(
                        secret_key,
                        [{'%share': share} for share in shares],
                        ignore
                    )
                    for shares in zip(*values)
                ]

            raise TypeError('array of compatible document shares expected')

        # Documents are general-purpose key-value mappings.
        keys = documents[0].keys()
        if all(document.keys() == keys for document in documents[1:]):
            # For ignored keys, unification is not performed and
            # they are omitted from the results.
            return {
                key: unify(
                    secret_key,
                    [document[key] for document in documents],
                    ignore
                )
                for key in keys
                if key not in ignore
            }

    # Base case: all documents must be equivalent.
    if all(documents[0] == document for document in documents[1:]):
        return documents[0]

    raise TypeError('array of compatible document shares expected')
if __name__ == '__main__':
    # Run every doctest example embedded in this module's docstrings when
    # the file is executed directly as a script.
    doctest.testmod()  # pragma: no cover