199 lines
4.8 KiB
Python
199 lines
4.8 KiB
Python
# coding: utf-8
|
|
from __future__ import unicode_literals, division, absolute_import, print_function
|
|
|
|
import sys
|
|
import hashlib
|
|
import math
|
|
|
|
from ._asn1 import int_from_bytes, int_to_bytes
|
|
from ._errors import pretty_message
|
|
from ._types import type_name, byte_cls, int_types
|
|
|
|
|
|
if sys.version_info >= (3,):
    def chr_cls(num):
        """
        Converts an integer byte value into a one-byte byte string

        :param num:
            An integer in the range 0-255

        :return:
            A byte string of length 1
        """

        return bytes((num,))
else:
    # On Python 2 the byte string type is str, so chr() already does this
    chr_cls = chr
|
|
|
|
|
|
# Public API of this module
__all__ = ['pkcs12_kdf']
|
|
|
|
|
|
# Digest size (u) and block size (v), both in bytes, for each supported hash.
# RFC 7292 expresses u and v in bits.
_PKCS12_HASH_SIZES = {
    'md5': (16, 64),
    'sha1': (20, 64),
    'sha224': (28, 64),
    'sha256': (32, 64),
    'sha384': (48, 128),
    'sha512': (64, 128),
}


def _pkcs12_fill(data, block_size):
    """
    Steps 2 and 3 of RFC 7292 appendix B.2 - concatenates copies of data and
    truncates the result to the smallest multiple of block_size that is at
    least len(data) bytes long

    :param data:
        A byte string (the salt or the encoded password)

    :param block_size:
        The hash block size (v) in bytes

    :return:
        A byte string whose length is a multiple of block_size, or an empty
        byte string when data is empty
    """

    if not data:
        return b''

    target_len = block_size * ((len(data) + block_size - 1) // block_size)
    # One extra copy guarantees the repeated string is at least target_len long
    repeats = target_len // len(data) + 1
    return (data * repeats)[0:target_len]


def _pkcs12_add_to_blocks(data, increment, block_size):
    """
    Step 6C of RFC 7292 appendix B.2 - treats data as consecutive
    block_size-byte big-endian integers I_j and replaces each one with
    (I_j + increment) mod 2^(8 * block_size)

    :param data:
        A byte string whose length is a multiple of block_size

    :param increment:
        The integer to add to every block (B + 1 in the RFC)

    :param block_size:
        The hash block size (v) in bytes

    :return:
        A byte string the same length as data
    """

    zero_pad = b'\x00' * block_size
    output = []
    for start in range(0, len(data), block_size):
        summed = int_to_bytes(int_from_bytes(data[start:start + block_size]) + increment)
        # int_to_bytes() yields a minimal-length encoding. Left-pad, then keep
        # the low-order block_size bytes. This drops any carry out of the top
        # byte (the "mod 2^v") and keeps each block exactly block_size bytes
        # even when the sum has leading zero bytes.
        output.append((zero_pad + summed)[-block_size:])
    return b''.join(output)


def pkcs12_kdf(hash_algorithm, password, salt, iterations, key_length, id_):
    """
    KDF from RFC7292 appendix b.2 - https://tools.ietf.org/html/rfc7292#page-19

    :param hash_algorithm:
        The string name of the hash algorithm to use: "md5", "sha1", "sha224",
        "sha256", "sha384", "sha512"

    :param password:
        A byte string of the password to use as an input to the KDF. It is
        decoded as UTF-8 and re-encoded as UTF-16BE with a two-byte null
        terminator before being used.

    :param salt:
        A cryptographic random byte string

    :param iterations:
        The number of iterations to use when deriving the key

    :param key_length:
        The length of the desired key in bytes

    :param id_:
        The ID of the usage - 1 for key, 2 for iv, 3 for mac

    :raises:
        TypeError - when password or salt is not a byte string, or when
        iterations or key_length is not an integer
        ValueError - when iterations or key_length is less than 1, or when
        hash_algorithm or id_ is not one of the supported values
        UnicodeDecodeError - when password is not valid UTF-8

    :return:
        The derived key as a byte string
    """

    if not isinstance(password, byte_cls):
        raise TypeError(pretty_message(
            '''
            password must be a byte string, not %s
            ''',
            type_name(password)
        ))

    if not isinstance(salt, byte_cls):
        raise TypeError(pretty_message(
            '''
            salt must be a byte string, not %s
            ''',
            type_name(salt)
        ))

    if not isinstance(iterations, int_types):
        raise TypeError(pretty_message(
            '''
            iterations must be an integer, not %s
            ''',
            type_name(iterations)
        ))

    if iterations < 1:
        raise ValueError(pretty_message(
            '''
            iterations must be greater than 0 - is %s
            ''',
            repr(iterations)
        ))

    if not isinstance(key_length, int_types):
        raise TypeError(pretty_message(
            '''
            key_length must be an integer, not %s
            ''',
            type_name(key_length)
        ))

    if key_length < 1:
        raise ValueError(pretty_message(
            '''
            key_length must be greater than 0 - is %s
            ''',
            repr(key_length)
        ))

    if hash_algorithm not in _PKCS12_HASH_SIZES:
        raise ValueError(pretty_message(
            '''
            hash_algorithm must be one of "md5", "sha1", "sha224", "sha256",
            "sha384", "sha512", not %s
            ''',
            repr(hash_algorithm)
        ))

    if id_ not in set([1, 2, 3]):
        raise ValueError(pretty_message(
            '''
            id_ must be one of 1, 2, 3, not %s
            ''',
            repr(id_)
        ))

    # The password is fed to the KDF as a null-terminated UTF-16BE string
    utf16_password = password.decode('utf-8').encode('utf-16be') + b'\x00\x00'

    algo = getattr(hashlib, hash_algorithm)

    # u and v values are bytes (not bits as in the RFC)
    u, v = _PKCS12_HASH_SIZES[hash_algorithm]

    # Step 1 - diversifier D: v copies of the ID byte
    d = chr_cls(id_) * v

    # Steps 2, 3 and 4 - I = S || P, each stretched to a multiple of v bytes
    i = _pkcs12_fill(salt, v) + _pkcs12_fill(utf16_password, v)

    # Step 5 - number of u-byte hash outputs needed to cover key_length
    c = (key_length + u - 1) // u

    blocks = []
    for num in range(1, c + 1):
        # Step 6A - A_i = H^iterations(D || I)
        a_i = algo(d + i).digest()
        for _ in range(1, iterations):
            a_i = algo(a_i).digest()
        blocks.append(a_i)

        # Steps 6B and 6C only prepare I for the next block
        if num < c:
            # Step 6B - B is A_i repeated and truncated to v bytes
            b = (a_i * (v // u + 1))[0:v]

            # Step 6C - I_j = (I_j + B + 1) mod 2^v for every v-byte block
            i = _pkcs12_add_to_blocks(i, int_from_bytes(b) + 1, v)

    # Steps 7 and 8 - A = A_1 || ... || A_c, truncated to key_length bytes
    return b''.join(blocks)[0:key_length]
|