魂の生命の領域

AWS とか Python とか本読んだ感想とか哲学とか書きます

Cognito の USER_SRP_AUTH を Python で理解したい

Cognito の USER_SRP_AUTH とは

Cognito には、一般にイメージするユーザー名とパスワードを直接サーバーに送信する認証フロー以外にもいくつか種類がありそのうちの一つが USER_SRP_AUTH と呼ばれるものです。

簡単にいうと、ユーザー名とパスワードを直接サーバーに送信するのではなく、それらを特定の計算式に基づいてハッシュ化して送信します。

流れとしては、以下のように 2 段階になります。(※ MFA が設定されていると最後にもう1段階追加されます)

  1. クライアント: AuthFlow に USER_SRP_AUTH を指定して InitiateAuth API を実行
  2. Cognito: チャレンジを要求
  3. クライアント: RespondToAuthChallenge API を実行
  4. Cognito: トークンを返却して認証が成功

ユーザーのパスワードは初回ログイン後に変更する必要がありますが、それを済ませた場合はこのようになります。

  1. クライアント: InitiateAuth API を実行
  2. Cognito: PASSWORD_VERIFIER チャレンジを要求
  3. クライアント: AuthFlow に PASSWORD_VERIFIER を指定して RespondToAuthChallenge API を実行
  4. Cognito: トークンを返却して認証が成功

今回はこの 1 と 3 を Python で実装したいというお話になります。

ソースコード

自分で書いたコードはこれです。とは言え後で述べるように参考元を読み解きながら写経している形になります。ただしコピペではないです。

github.com

ここまでの参考情報

  • ドキュメント

docs.aws.amazon.com

  • Black Belt の P.33 ~ P.35 の図がわかりやすいです。

https://d1.awsstatic.com/webinars/jp/pdf/services/20200630_AWS_BlackBelt_Amazon%20Cognito.pdf

なぜ Python で頑張るのか

前述のドキュメントからもわかるように、ブラウザ用には JavaScriptSDK があり、モバイル用には AndroidiOS 向けに SDK が用意されています。なので基本的にはこれを使ってやればいいことになります。

だけどクライアントのアプリが他の言語で動いていた場合は困りますよね。

私は Python が一番得意なので Python を想定していますが、例えば Django 製のアプリだと Python の実装が必要そうです。

元ネタ

ネタは Amplify SDK に組み込まれたライブラリに実装があります。

github.com

そしてこれを忠実に再現した Python 版があります。

github.com

ただしメンテの具合とかサポートしていると明言している Python のバージョンからもわかるようにプロダクトで使うのはちょっと躊躇うというかなんというか。

README だけ見ると Python 3.6 までしかサポートしてない風なのですが、 settings.py を見ると Python 3.8 までは想定しているみたいなんですよね。メンテしてくれ。

なので、自前で実装できたら嬉しいなぁという思いで Python での実装を追いかけてみようというお話です。

※注意

この記事自体手元のお勉強メモをつなぎ合わせたもので、自分用の文章として元々書いていたため他の人が読むと思考の流れが破茶滅茶な可能性があります。

RFC あるよ

SRP(Secure Remote Password)プロトコルRFC には計算式が載っています。やっていることは多分これの通りだと思います。

https://datatracker.ietf.org/doc/html/rfc2945

式だけならこっちの方が見やすいです。

https://datatracker.ietf.org/doc/html/rfc5054#section-2.6

この記事内だけの前提

MFA 認証と client_secret については省略します。(それがまず何なのかの説明やらも含め省略させてください)

全体像

SRP のプロトコルそのものと Cognito 特有の処理を全部突っ込みました。(直前に省くと明言した部分を除く)

色分けは以下のルールに従います。

  • オレンジ色の文字はサーバー(要するに Cognito )に直接送信するパラメータ
  • 赤文字はサーバーから返却されるパラメータ
  • グレーの枠は API 呼び出し

クライアント側での計算を概観

まずクライアント側での計算を概観します。RFC からの抜粋したものがわかりやすいです。

I, P = <read from user>
N, g, s, B = <read from server>
a = random()
A = g^a % N
u = SHA1(PAD(A) | PAD(B))
k = SHA1(N | PAD(g))
x = SHA1(s | SHA1(I | ":" | P))
<premaster secret> = (B - (k * g^x)) ^ (a + (u * x)) % N

ちなみに ^ は冪乗、% は剰余、 | は文字列の結合です。 Python だと ^ は XOR、 | は OR なのでご注意ください。

そして、Cognito では SHA1 ではなく SHA256 が使われています。そっちの方が推奨されてるからね。

また、Cognito 用に補足すると、

userpool_id = ap-northeast-1_xxxx
username = user

としたとき、 Ixxxxuser にあたります。 これによってユーザープール内でユーザー名が一意であれば OK、みたいな実装になっているのかなと、そんな気がします。

サーバー側での計算を概観

これも RFC から抜粋です。

N, g, s, v = <read from password file>
b = random()
k = SHA1(N | PAD(g))
B = k*v + g^b % N
A = <read from client>
u = SHA1(PAD(A) | PAD(B))
<premaster secret> = (A * v^u) ^ b % N

サーバー側のロジックは見えないけど、こちらも SHA1 ではなく SHA256 が使われているんだろうと思います。

全体の流れ

流れは冒頭にも貼ったドキュメントの通りこんな感じです。

  1. クライアントは InitiateAuth API を実行
  2. Cognito はチャレンジを要求
  3. クライアントは RespondToAuthChallenge API を実行
  4. Cognito がトークンを返却して認証が成功

docs.aws.amazon.com

そして今回は 2. のチャレンジとして PASSWORD_VERIFIER のみを想定することにします。

0. 多用される計算

一連の流れで多様されるユーティリティ的な関数を先にまとめておきましょう。

  • 16進数と整数(10進数)の間の相互変換
def hex_to_long(hex_string):
    return int(hex_string, 16)

def long_to_hex(long_num):
    return '%x' % long_num
  • ハッシュ化の関数
    • バイト列を SHA256 でハッシュ化します。
    • 64文字に満たない場合は先頭を '0' で埋めます。
def hash_sha256(buf):
    a = hashlib.sha256(buf).hexdigest()
    return (64 - len(a)) * '0' + a

https://docs.python.org/ja/3/library/hashlib.html

  • hex_hash(hex_string)
    • 先ほどの hash_sha256 関数を使って、16進数の文字列をバイト列に変換してからハッシュ化します。
hex_hash = hash_sha256(bytearray.fromhex(hex_string))
  • pad_hex(long_int)
    • RFC で度々出てくる PAD() 関数の実装です。
def pad_hex(long_int):
    if not isinstance(long_int, six.string_types):
        hash_str = long_to_hex(long_int)
    else:
        hash_str = long_int
    if len(hash_str) % 2 == 1:
        hash_str = '0%s' % hash_str
    elif hash_str[0] in '89ABCDEFabcdef':
        hash_str = '00%s' % hash_str
    return hash_str

やっていることは(見たままですが)以下の通りです。

  • long_int が文字列であればそのまま hash_str にセット
    • 文字列でない場合は long_to_hex で文字列にしてから hash_str にセット
  • hash_str の長さが奇数であれば、先頭に '0' を追加
  • hash_str の長さが偶数かつ先頭が '89ABCDEFabcdef' のいずれかであれば、先頭に '00' を追加
    • それ以外はそのまま

どんな意味があるのかはまだよくわかりません!

1. SRP_A を算出する

では早速計算していきます。

まず初めに RFC でいうところの A を計算します。

A = g^a % N

実装はそのままです。 pow(a, b, c)a ^ b % cPython の文法では a ** b % c )なのでそれを使います。

def calculate_a(self):
    """
    A = g^a % N を計算する
    """
    A = pow(self.g, self.small_a_value, self.N)
    if A % self.N == 0:
        raise ValueError('Illegal paramater. A mod N cannot be 0.')
 
    return A

A が N で割り切れてしまうとエラーにする点に注意です。

ここで、 Cognito ではそれぞれ以下の値が設定されています。

  • g = 2
  • a = 128 byteのランダム文字列を16進数に変換したあと、さらに10進数に変換し N で割った剰余
  • N = ライブラリ側で定まった大きな値(16進数)を10進数に変換したもの

a の具体的な実装を見ていきます。

a = generate_random_small_a()

def generate_random_small_a():
    random_long_int = get_random(128)
    return random_long_int % big_n

def get_random(nbytes):
    random_hex = binascii.hexlify(os.urandom(nbytes))
    return hex_to_long(random_hex)

それぞれこんな感じの挙動です。

>>> a = os.urandom(4)
>>> a
b'\xd9Uo\xfe'

>>> binascii.hexlify(a)
b'd9556ffe'

>>> binascii.hexlify(b'U')
b'55'
>>> binascii.hexlify(b'o')
b'6f'

\xd9 U o \xfe みたいな感じで、数字の部分は16進数で表記されており、その全体をさらに16進数表記 d9 55 6f fe にしているようです。

その上で、 hex_to_long(random_hex) を使って10進数の int 型にしているんですね〜

また、 N は 4. 3072-bit Group と同じ値です。

2. InitiateAuth API を実行

SRP_A の値が計算できましたので、それを使って InitiateAuth API を実行します。

client.initiate_auth(
    AuthFlow='USER_SRP_AUTH',
    AuthParameters={
        'USERNAME': 'demo-user',
        'SRP_A': <1.で計算した値を long_to_hex で変換したもの>
    },
    ClientId=client_id
)

この時点ではまだ自分のユーザー名しか教えていない点にご注意ください。

3. Challenge Response の生成

InitiateAuth API のレスポンスを見ると、ChallengeName の値が PASSWORD_VERIFIER です。 PASSWORD_VERIFIER_CHALLENGE を要求されている、ということになります。

具体的なレスポンス形式は以下の通りです。 ChallengeParameters に含まれる値を使ってユーザー名とパスワードをハッシュ化した値を送信し、自分がちゃんとしたユーザーだと証明してね!と要求されているわけです。

{
    'ChallengeName': 'PASSWORD_VERIFIER',
    'ChallengeParameters': {
        'SALT': '3b9cadfa7530456cc432931b15bf9951', # 定期的に変わるのかわからないけど、毎回変わるわけではなさそう
        'SECRET_BLOCK': 'xxxxx', # 長い base64 エンコードされた文字列
        'SRP_B': 'xxxxx', # 長い16進数文字列
        'USERNAME': 'demo-user',
        'USER_ID_FOR_SRP': 'demo-user'
    }
}

それでは、ここからそのハッシュ化の流れを追いかけていきます。

3-1. timestamp の生成

まず、前準備として timestamp を生成します。だいぶ後の方で使います。

Cognito の仕様として、日(day)が一桁の場合は十の位のゼロを省きます。

>>> import re
>>> import datetime
>>> t = datetime.datetime.utcnow().strftime("%a %b %d %H:%M:%S UTC %Y")
>>> t
'Sat Apr 09 04:21:27 UTC 2022'
>>> re.sub(r" 0(\d) ", r" \1 ", t)
'Sat Apr 9 04:21:27 UTC 2022'

使われている正規表現について少し補足しておきます。

re.sub(r" 0(\d) ", r" \1 ", t)
re.sub(<置き換え対象の正規表現>, <置き換え後の正規表現>, 対象の文字列)
r' 0(\d) ' → 半角スペース + 0 + 任意の10進整数(一桁) + 半角スペース
r' \1 ' → 半角スペース + 任意の10進整数 + 半角スペース

3-2. hkdf の算出

ハッシュ化に際して、ハッシュ化に用いる hkdf (HMAC-based Key Derivation Function)を計算します。

hkdf = get_authentication_key(
    username, # USER_ID_FOR_SRP
    password,
    srp_b, # hex_to_long(SRP_B)
    salt # SALT
)

username と password はクライアントがユーザーから取得する値です。実際には username は先ほど教えており、それが ChallengeParameters にも入ってくるのでそれを使います。

srp_b と salt も ChallengeParameters として与えられるものです。

では、この get_authentication_key 関数の実装を紐解いていきます。

3-3. u の算出

RFCu = SHA1(PAD(A) | PAD(B)) と書かれているものです。

変数名を補足すると

  • A は srp_a
  • B は srp_b

です。実装は割と見たまんまですね。

u = hex_to_long(hex_hash(pad_hex(srp_a) + pad_hex(srp_b)))

最後に10進数に変換しているところがポイントでしょうか。

以下は冒頭で定義した関数たちです。

  • hex_to_long(hex_string)
  • hex_hash(hex_string)
  • pad_hex(long_int)

3-4. x の算出

RFCx = SHA1(s | SHA1(I | ":" | P)) と書かれているものです。

x = hex_to_long(hex_hash(pad_hex(salt) + full_password_hash))

それぞれ、RFC の式とは以下のように対応します。

  • s は salt
  • I はユーザー名
    • ただしユーザープール ID 末尾の文字列と Cognito ユーザー名を結合したもの
  • P はパスワード

なので、full_password_hash は以下のようにして得られます。

pool_id = 'ap-northeast-1_Hogehogexx'
full_password = f'{pool_id.split('_')[1]}{username}:{password}'
full_password_hash = hash_sha256(full_password.encode('utf-8'))

3-5. s の算出

これまでに u と x を計算したので、それらを使って <premaster secret> を計算します。

<premaster secret> = (B - (k * g^x)) ^ (a + (u * x)) % N

コード上では s と名付けられています。 さっき RFC に出てくた s は SALT のことなので混同しないようご注意下さい。

また、RFC の式にはしれっと書いてありますが、ここで初めて使われる k は k = SHA1(N | PAD(g)) のように計算します。

k の実装はこれです。

k = hex_to_long(hex_hash('00' + N + '0' + g))

では <premaster secret> の実装を見てみましょう。こんな流れです。

g_mod_pow_xn = g ** x % N

int_value2 = srp_b - k * g_mod_pow_xn

s = int_value2 ** (a + u * x) % N

1本の式に落とし込むとこうなります。

s = (srp_b - (k * g ** x) % N) ** (a + (u * x)) % N

お気づきでしょうか。以下に再掲する RFC の式と少し違います。(Python の文法に寄せて ^** に書き換えています*1

s = (srp_b - (k * g ** x)) ** (a + (u * x)) % N

RFC では g ** x と書いているところが g ** x (% N) ですね。 理由はおそらく、計算量を減らすためだと思います。 N で割った余りは0 から N - 1 の範囲に抑えられるので、いい感じに大きさを制御できます。

で、これが同じ値になるかですが、なります。付録 1. に証明を書いておきました。多分正しいと思います。

3-6. hkdf の算出

これで <parameter secret> つまり s がわかりました。 u と s を使って hkdf を導出します。

hkdf = compute_hkdf(
    bytearray.fromhex(pad_hex(s)),
    bytearray.fromhex(pad_hex(long_to_hex(u)))
)

そして、compute_hkdf 関数は以下のように定義されています。

info_bits = bytearray('Caldera Derived Key', 'utf-8') # ソースコードにベタ書きされいる

def compute_hkdf(ikm, salt):
    prk = hmac.new(salt, ikm, hashlib.sha256).digest()
    info_bits_update = info_bits + bytearray(chr(1), 'utf-8') # 結果: bytearray(b'Caldera Derived Key\x01')
    hmac_hash = hmac.new(prk, info_bits_update, hashlib.sha256).digest()
    return hmac_hash[:16]  # 先頭 16 文字

まず、 <parameter secret> をキーとして u(srp_a と srp_b をつなぎ合わせたもの)をハッシュ化し、そいつを prk とします。

さらにその結果を使って、 info_bits_update(中身は固定値で bytearray(b'Caldera Derived Key\x01'))をまたハッシュ化します。

その結果の先頭16文字を hkbf として、呼び出し元に返します。

ここの流れがよくわからないので、また今度調べておきます。

en.wikipedia.org

3-7. PASSWORD_CLAIM_SIGNATURE の算出

hkdf が得られたので、それをもとに Challenge Parameters の各値をハッシュ化していきます。

まず、ChallengeParameter の SECRET_BLOCK は base64 エンコードされているらしく、そいつをデコードします。

secret_block_bytes = base64.standard_b64decode(secret_block_b64)

そして、以下をバイト文字列として結合したものを、ハッシュ化する対象の文字列とします。

  • UserPoolId(の末尾)
  • USERNAME
  • デコードした SECRET_BLOCK
  • タイムスタンプ

実装はそのまんまです。

msg = bytearray(pool_id.split('_')[1], 'utf-8') + \
    bytearray(user_id_for_srp, 'utf-8') + \
    bytearray(secret_block_bytes) + \
    bytearray(timestamp, 'utf-8')

この msg をハッシュ化した hmac_obj を生成します。ハッシュ化に用いるキーは先ほど計算した hkdf です。

hmac_obj = hmac.new(hkdf, msg, digestmod=hashlib.sha256)

結果を base64 のバイト列にエンコードして、署名文字列とします。

signature_string = base64.standard_b64encode(hmac_obj.digest())

これでようやく PASSWORD_VERIFIER チャレンジに応答できます。

4. RespondToAuthChallenge API の実行

ChallengeName を PASSWORD_VERIFIER とし、RespondToAuthChallenge API を実行します。

これで認証成功です。 JWT として IdToken が返ってきますので、それを使って好きにできます。

client.respond_to_auth_challenge(
    ClientId=client_id,
    ChallengeName='PASSWORD_VERIFIER',
    ChallengeResponses={
        'TIMESTAMP': timestamp,
        'USERNAME': user_id_for_srp, # ChallengeParameters からそのまま
        'PASSWORD_CLAIM_SECRET_BLOCK': secret_block_b64, # ChallengeParameters からそのまま
        'PASSWORD_CLAIM_SIGNATURE': signature_string.decode('utf-8') # バイト列を文字列にする
    }
)

お疲れ様でした。

base64 周りの補足

base64.standard_b64decode() は下のように base64.b64decode() と同じものみたいです。

def standard_b64encode(s):
    """Encode bytes-like object s using the standard Base64 alphabet.

    The result is returned as a bytes object.
    """
    return b64encode(s)

def standard_b64decode(s):
    """Decode bytes encoded with the standard Base64 alphabet.

    Argument s is a bytes-like object or ASCII string to decode.  The result
    is returned as a bytes object.  A binascii.Error is raised if the input
    is incorrectly padded.  Characters that are not in the standard alphabet
    are discarded prior to the padding check.
    """
    return b64decode(s)

最後に

MFA とかが絡むともう1段階あるので、その辺もいつかフォローしたいです。

付録

付録 1. s の二通りの計算が同じであることの証明

注意: 冪乗は RFC に寄せて ^ と記載しています。 以下のように定義される s_1 と s_2 が、s_1 - s_2 = 0 になればよい。

s_1 = (srp_b - (k * g ^ x) % N) ^ (a + u * x) % N
s_2 = (srp_b - (k * g ^ x)) ^ (a + u * x) % N

こんな感じで % で書いた部分をバラしていく。

s = a % N => s = a - m * N

計算するとこうなる。

s_1 = (srp_b + k * m1 * N - k * g ^ x) ^ (a + u * x) - n1 * N
s_2 = (srp_b - k * g ^ x) ^ (a + u * x) - n2 * N

ここで、 m1, n1, n2 は適当な整数である。

見づらいので、適当に置き換えを行う。

X = a + u * x
Y = srp_b - k * g ^ x

とすればシンプルになる。

s_1 = (Y + k * m1 * N) ^ X - n1 * N
s_2 = Y ^ X - n2 * N

s_1 - s_2 を考えると、

s_1 - s_2 = (Y + k * m1 * N) ^ X - Y ^ X - (n1 -n2) * N

となり、二項定理を使えば B ^ A の項がうまい具合に打ち消されて N が掛かった項だけ残ることがわかる。

従って、以下のように書ける。

(s_1 - s_2) % N = 0 ... (*)

ここで、 s_1 と s_2 は元々 N で割った余りだったことを思い出せば

s_1 % N = s_1
s_2 % N = s_2

よって (*) は

(s_1 - s_2) % N = 0
=> s_1 % N - s_2 % N = 0
=> s_1 - s_2 = 0

したがって s_1 = s_2 が示せた。

最初にこれを証明しようとしたとき、(*) まではすぐ辿り着けたんですが、その後のs_1 % N = s_1, s_2 % N = s_2 が成り立つことに気づかなくて、わからね〜〜〜と言いながら一週間ぐらい放置していました。余談です。

*1:記法がごっちゃになっているとコメントいただきまして修正しました!ありがとうございます!そしてすみません!