LCOV - code coverage report
Current view: top level - src/common - scram-common.c (source / functions) Hit Total Coverage
Test: PostgreSQL 13devel Lines: 83 96 86.5 %
Date: 2019-09-22 07:07:17 Functions: 8 8 100.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : /*-------------------------------------------------------------------------
       2             :  * scram-common.c
       3             :  *      Shared frontend/backend code for SCRAM authentication
       4             :  *
       5             :  * This contains the common low-level functions needed in both frontend and
       6             :  * backend, for implement the Salted Challenge Response Authentication
       7             :  * Mechanism (SCRAM), per IETF's RFC 5802.
       8             :  *
       9             :  * Portions Copyright (c) 2017-2019, PostgreSQL Global Development Group
      10             :  *
      11             :  * IDENTIFICATION
      12             :  *    src/common/scram-common.c
      13             :  *
      14             :  *-------------------------------------------------------------------------
      15             :  */
      16             : #ifndef FRONTEND
      17             : #include "postgres.h"
      18             : #else
      19             : #include "postgres_fe.h"
      20             : #endif
      21             : 
      22             : #include "common/base64.h"
      23             : #include "common/scram-common.h"
      24             : #include "port/pg_bswap.h"
      25             : 
      26             : #define HMAC_IPAD 0x36
      27             : #define HMAC_OPAD 0x5C
      28             : 
      29             : /*
      30             :  * Calculate HMAC per RFC2104.
      31             :  *
      32             :  * The hash function used is SHA-256.
      33             :  */
      34             : void
      35      368904 : scram_HMAC_init(scram_HMAC_ctx *ctx, const uint8 *key, int keylen)
      36             : {
      37             :     uint8       k_ipad[SHA256_HMAC_B];
      38             :     int         i;
      39             :     uint8       keybuf[SCRAM_KEY_LEN];
      40             : 
      41             :     /*
      42             :      * If the key is longer than the block size (64 bytes for SHA-256), pass
      43             :      * it through SHA-256 once to shrink it down.
      44             :      */
      45      368904 :     if (keylen > SHA256_HMAC_B)
      46             :     {
      47             :         pg_sha256_ctx sha256_ctx;
      48             : 
      49       32768 :         pg_sha256_init(&sha256_ctx);
      50       32768 :         pg_sha256_update(&sha256_ctx, key, keylen);
      51       32768 :         pg_sha256_final(&sha256_ctx, keybuf);
      52       32768 :         key = keybuf;
      53       32768 :         keylen = SCRAM_KEY_LEN;
      54             :     }
      55             : 
      56      368904 :     memset(k_ipad, HMAC_IPAD, SHA256_HMAC_B);
      57      368904 :     memset(ctx->k_opad, HMAC_OPAD, SHA256_HMAC_B);
      58             : 
      59     3908104 :     for (i = 0; i < keylen; i++)
      60             :     {
      61     3539200 :         k_ipad[i] ^= key[i];
      62     3539200 :         ctx->k_opad[i] ^= key[i];
      63             :     }
      64             : 
      65             :     /* tmp = H(K XOR ipad, text) */
      66      368904 :     pg_sha256_init(&ctx->sha256ctx);
      67      368904 :     pg_sha256_update(&ctx->sha256ctx, k_ipad, SHA256_HMAC_B);
      68      368904 : }
      69             : 
      70             : /*
      71             :  * Update HMAC calculation
      72             :  * The hash function used is SHA-256.
      73             :  */
      74             : void
      75      369426 : scram_HMAC_update(scram_HMAC_ctx *ctx, const char *str, int slen)
      76             : {
      77      369426 :     pg_sha256_update(&ctx->sha256ctx, (const uint8 *) str, slen);
      78      369426 : }
      79             : 
      80             : /*
      81             :  * Finalize HMAC calculation.
      82             :  * The hash function used is SHA-256.
      83             :  */
      84             : void
      85      368904 : scram_HMAC_final(uint8 *result, scram_HMAC_ctx *ctx)
      86             : {
      87             :     uint8       h[SCRAM_KEY_LEN];
      88             : 
      89      368904 :     pg_sha256_final(&ctx->sha256ctx, h);
      90             : 
      91             :     /* H(K XOR opad, tmp) */
      92      368904 :     pg_sha256_init(&ctx->sha256ctx);
      93      368904 :     pg_sha256_update(&ctx->sha256ctx, ctx->k_opad, SHA256_HMAC_B);
      94      368904 :     pg_sha256_update(&ctx->sha256ctx, h, SCRAM_KEY_LEN);
      95      368904 :     pg_sha256_final(&ctx->sha256ctx, result);
      96      368904 : }
      97             : 
      98             : /*
      99             :  * Calculate SaltedPassword.
     100             :  *
     101             :  * The password should already be normalized by SASLprep.
     102             :  */
     103             : void
     104          90 : scram_SaltedPassword(const char *password,
     105             :                      const char *salt, int saltlen, int iterations,
     106             :                      uint8 *result)
     107             : {
     108          90 :     int         password_len = strlen(password);
     109          90 :     uint32      one = pg_hton32(1);
     110             :     int         i,
     111             :                 j;
     112             :     uint8       Ui[SCRAM_KEY_LEN];
     113             :     uint8       Ui_prev[SCRAM_KEY_LEN];
     114             :     scram_HMAC_ctx hmac_ctx;
     115             : 
     116             :     /*
     117             :      * Iterate hash calculation of HMAC entry using given salt.  This is
     118             :      * essentially PBKDF2 (see RFC2898) with HMAC() as the pseudorandom
     119             :      * function.
     120             :      */
     121             : 
     122             :     /* First iteration */
     123          90 :     scram_HMAC_init(&hmac_ctx, (uint8 *) password, password_len);
     124          90 :     scram_HMAC_update(&hmac_ctx, salt, saltlen);
     125          90 :     scram_HMAC_update(&hmac_ctx, (char *) &one, sizeof(uint32));
     126          90 :     scram_HMAC_final(Ui_prev, &hmac_ctx);
     127          90 :     memcpy(result, Ui_prev, SCRAM_KEY_LEN);
     128             : 
     129             :     /* Subsequent iterations */
     130      368640 :     for (i = 2; i <= iterations; i++)
     131             :     {
     132      368550 :         scram_HMAC_init(&hmac_ctx, (uint8 *) password, password_len);
     133      368550 :         scram_HMAC_update(&hmac_ctx, (const char *) Ui_prev, SCRAM_KEY_LEN);
     134      368550 :         scram_HMAC_final(Ui, &hmac_ctx);
     135    12162150 :         for (j = 0; j < SCRAM_KEY_LEN; j++)
     136    11793600 :             result[j] ^= Ui[j];
     137      368550 :         memcpy(Ui_prev, Ui, SCRAM_KEY_LEN);
     138             :     }
     139          90 : }
     140             : 
     141             : 
     142             : /*
     143             :  * Calculate SHA-256 hash for a NULL-terminated string. (The NULL terminator is
     144             :  * not included in the hash).
     145             :  */
     146             : void
     147         108 : scram_H(const uint8 *input, int len, uint8 *result)
     148             : {
     149             :     pg_sha256_ctx ctx;
     150             : 
     151         108 :     pg_sha256_init(&ctx);
     152         108 :     pg_sha256_update(&ctx, input, len);
     153         108 :     pg_sha256_final(&ctx, result);
     154         108 : }
     155             : 
     156             : /*
     157             :  * Calculate ClientKey.
     158             :  */
     159             : void
     160          76 : scram_ClientKey(const uint8 *salted_password, uint8 *result)
     161             : {
     162             :     scram_HMAC_ctx ctx;
     163             : 
     164          76 :     scram_HMAC_init(&ctx, salted_password, SCRAM_KEY_LEN);
     165          76 :     scram_HMAC_update(&ctx, "Client Key", strlen("Client Key"));
     166          76 :     scram_HMAC_final(result, &ctx);
     167          76 : }
     168             : 
     169             : /*
     170             :  * Calculate ServerKey.
     171             :  */
     172             : void
     173          80 : scram_ServerKey(const uint8 *salted_password, uint8 *result)
     174             : {
     175             :     scram_HMAC_ctx ctx;
     176             : 
     177          80 :     scram_HMAC_init(&ctx, salted_password, SCRAM_KEY_LEN);
     178          80 :     scram_HMAC_update(&ctx, "Server Key", strlen("Server Key"));
     179          80 :     scram_HMAC_final(result, &ctx);
     180          80 : }
     181             : 
     182             : 
     183             : /*
     184             :  * Construct a verifier string for SCRAM, stored in pg_authid.rolpassword.
     185             :  *
     186             :  * The password should already have been processed with SASLprep, if necessary!
     187             :  *
     188             :  * If iterations is 0, default number of iterations is used.  The result is
     189             :  * palloc'd or malloc'd, so caller is responsible for freeing it.
     190             :  */
     191             : char *
     192          44 : scram_build_verifier(const char *salt, int saltlen, int iterations,
     193             :                      const char *password)
     194             : {
     195             :     uint8       salted_password[SCRAM_KEY_LEN];
     196             :     uint8       stored_key[SCRAM_KEY_LEN];
     197             :     uint8       server_key[SCRAM_KEY_LEN];
     198             :     char       *result;
     199             :     char       *p;
     200             :     int         maxlen;
     201             :     int         encoded_salt_len;
     202             :     int         encoded_stored_len;
     203             :     int         encoded_server_len;
     204             :     int         encoded_result;
     205             : 
     206          44 :     if (iterations <= 0)
     207           0 :         iterations = SCRAM_DEFAULT_ITERATIONS;
     208             : 
     209             :     /* Calculate StoredKey and ServerKey */
     210          44 :     scram_SaltedPassword(password, salt, saltlen, iterations,
     211             :                          salted_password);
     212          44 :     scram_ClientKey(salted_password, stored_key);
     213          44 :     scram_H(stored_key, SCRAM_KEY_LEN, stored_key);
     214             : 
     215          44 :     scram_ServerKey(salted_password, server_key);
     216             : 
     217             :     /*----------
     218             :      * The format is:
     219             :      * SCRAM-SHA-256$<iteration count>:<salt>$<StoredKey>:<ServerKey>
     220             :      *----------
     221             :      */
     222          44 :     encoded_salt_len = pg_b64_enc_len(saltlen);
     223          44 :     encoded_stored_len = pg_b64_enc_len(SCRAM_KEY_LEN);
     224          44 :     encoded_server_len = pg_b64_enc_len(SCRAM_KEY_LEN);
     225             : 
     226          44 :     maxlen = strlen("SCRAM-SHA-256") + 1
     227             :         + 10 + 1                /* iteration count */
     228             :         + encoded_salt_len + 1  /* Base64-encoded salt */
     229          44 :         + encoded_stored_len + 1    /* Base64-encoded StoredKey */
     230          44 :         + encoded_server_len + 1;   /* Base64-encoded ServerKey */
     231             : 
     232             : #ifdef FRONTEND
     233           0 :     result = malloc(maxlen);
     234           0 :     if (!result)
     235           0 :         return NULL;
     236             : #else
     237          44 :     result = palloc(maxlen);
     238             : #endif
     239             : 
     240          44 :     p = result + sprintf(result, "SCRAM-SHA-256$%d:", iterations);
     241             : 
     242             :     /* salt */
     243          44 :     encoded_result = pg_b64_encode(salt, saltlen, p, encoded_salt_len);
     244          44 :     if (encoded_result < 0)
     245             :     {
     246             : #ifdef FRONTEND
     247           0 :         free(result);
     248           0 :         return NULL;
     249             : #else
     250           0 :         elog(ERROR, "could not encode salt");
     251             : #endif
     252             :     }
     253          44 :     p += encoded_result;
     254          44 :     *(p++) = '$';
     255             : 
     256             :     /* stored key */
     257          44 :     encoded_result = pg_b64_encode((char *) stored_key, SCRAM_KEY_LEN, p,
     258             :                                    encoded_stored_len);
     259          44 :     if (encoded_result < 0)
     260             :     {
     261             : #ifdef FRONTEND
     262           0 :         free(result);
     263           0 :         return NULL;
     264             : #else
     265           0 :         elog(ERROR, "could not encode stored key");
     266             : #endif
     267             :     }
     268             : 
     269          44 :     p += encoded_result;
     270          44 :     *(p++) = ':';
     271             : 
     272             :     /* server key */
     273          44 :     encoded_result = pg_b64_encode((char *) server_key, SCRAM_KEY_LEN, p,
     274             :                                    encoded_server_len);
     275          44 :     if (encoded_result < 0)
     276             :     {
     277             : #ifdef FRONTEND
     278           0 :         free(result);
     279           0 :         return NULL;
     280             : #else
     281           0 :         elog(ERROR, "could not encode server key");
     282             : #endif
     283             :     }
     284             : 
     285          44 :     p += encoded_result;
     286          44 :     *(p++) = '\0';
     287             : 
     288             :     Assert(p - result <= maxlen);
     289             : 
     290          44 :     return result;
     291             : }

Generated by: LCOV version 1.13