LCOV - code coverage report
Current view: top level - src/common - scram-common.c (source / functions) Coverage Total Hit
Test: PostgreSQL 19devel Lines: 65.8 % 114 75
Test Date: 2026-03-11 21:15:11 Functions: 100.0 % 5 5
Legend: Lines:     hit not hit

            Line data    Source code
       1              : /*-------------------------------------------------------------------------
       2              :  * scram-common.c
       3              :  *      Shared frontend/backend code for SCRAM authentication
       4              :  *
       5              :  * This contains the common low-level functions needed in both frontend and
       6              :  * backend, for implementing the Salted Challenge Response Authentication
       7              :  * Mechanism (SCRAM), per IETF's RFC 5802.
       8              :  *
       9              :  * Portions Copyright (c) 2017-2026, PostgreSQL Global Development Group
      10              :  *
      11              :  * IDENTIFICATION
      12              :  *    src/common/scram-common.c
      13              :  *
      14              :  *-------------------------------------------------------------------------
      15              :  */
      16              : #ifndef FRONTEND
      17              : #include "postgres.h"
      18              : #else
      19              : #include "postgres_fe.h"
      20              : #endif
      21              : 
      22              : #include "common/base64.h"
      23              : #include "common/hmac.h"
      24              : #include "common/scram-common.h"
      25              : #ifndef FRONTEND
      26              : #include "miscadmin.h"
      27              : #endif
      28              : #include "port/pg_bswap.h"
      29              : 
      30              : /*
      31              :  * Calculate SaltedPassword.
      32              :  *
      33              :  * The password should already be normalized by SASLprep.  Returns 0 on
      34              :  * success, -1 on failure with *errstr pointing to a message about the
      35              :  * error details.
      36              :  */
      37              : int
      38          143 : scram_SaltedPassword(const char *password,
      39              :                      pg_cryptohash_type hash_type, int key_length,
      40              :                      const uint8 *salt, int saltlen, int iterations,
      41              :                      uint8 *result, const char **errstr)
      42              : {
      43          143 :     int         password_len = strlen(password);
      44          143 :     uint32      one = pg_hton32(1);
      45              :     int         i,
      46              :                 j;
      47              :     uint8       Ui[SCRAM_MAX_KEY_LEN];
      48              :     uint8       Ui_prev[SCRAM_MAX_KEY_LEN];
      49          143 :     pg_hmac_ctx *hmac_ctx = pg_hmac_create(hash_type);
      50              : 
      51          143 :     if (hmac_ctx == NULL)
      52              :     {
      53            0 :         *errstr = pg_hmac_error(NULL);  /* returns OOM */
      54            0 :         return -1;
      55              :     }
      56              : 
      57              :     /*
      58              :      * Iterate hash calculation of HMAC entry using given salt.  This is
      59              :      * essentially PBKDF2 (see RFC2898) with HMAC() as the pseudorandom
      60              :      * function.
      61              :      */
      62              : 
      63              :     /* First iteration */
      64          286 :     if (pg_hmac_init(hmac_ctx, (const uint8 *) password, password_len) < 0 ||
      65          286 :         pg_hmac_update(hmac_ctx, salt, saltlen) < 0 ||
      66          286 :         pg_hmac_update(hmac_ctx, (uint8 *) &one, sizeof(uint32)) < 0 ||
      67          143 :         pg_hmac_final(hmac_ctx, Ui_prev, key_length) < 0)
      68              :     {
      69            0 :         *errstr = pg_hmac_error(hmac_ctx);
      70            0 :         pg_hmac_free(hmac_ctx);
      71            0 :         return -1;
      72              :     }
      73              : 
      74          143 :     memcpy(result, Ui_prev, key_length);
      75              : 
      76              :     /* Subsequent iterations */
      77       542846 :     for (i = 1; i < iterations; i++)
      78              :     {
      79              : #ifndef FRONTEND
      80              :         /*
      81              :          * Make sure that this is interruptible as scram_iterations could be
      82              :          * set to a large value.
      83              :          */
      84       309206 :         CHECK_FOR_INTERRUPTS();
      85              : #endif
      86              : 
      87      1085406 :         if (pg_hmac_init(hmac_ctx, (const uint8 *) password, password_len) < 0 ||
      88      1085406 :             pg_hmac_update(hmac_ctx, (uint8 *) Ui_prev, key_length) < 0 ||
      89       542703 :             pg_hmac_final(hmac_ctx, Ui, key_length) < 0)
      90              :         {
      91            0 :             *errstr = pg_hmac_error(hmac_ctx);
      92            0 :             pg_hmac_free(hmac_ctx);
      93            0 :             return -1;
      94              :         }
      95              : 
      96     17909199 :         for (j = 0; j < key_length; j++)
      97     17366496 :             result[j] ^= Ui[j];
      98       542703 :         memcpy(Ui_prev, Ui, key_length);
      99              :     }
     100              : 
     101          143 :     pg_hmac_free(hmac_ctx);
     102          143 :     return 0;
     103              : }
     104              : 
     105              : 
     106              : /*
     107              :  * Calculate hash for a NULL-terminated string. (The NULL terminator is
     108              :  * not included in the hash).  Returns 0 on success, -1 on failure with *errstr
     109              :  * pointing to a message about the error details.
     110              :  */
     111              : int
     112          185 : scram_H(const uint8 *input, pg_cryptohash_type hash_type, int key_length,
     113              :         uint8 *result, const char **errstr)
     114              : {
     115              :     pg_cryptohash_ctx *ctx;
     116              : 
     117          185 :     ctx = pg_cryptohash_create(hash_type);
     118          185 :     if (ctx == NULL)
     119              :     {
     120            0 :         *errstr = pg_cryptohash_error(NULL);    /* returns OOM */
     121            0 :         return -1;
     122              :     }
     123              : 
     124          370 :     if (pg_cryptohash_init(ctx) < 0 ||
     125          370 :         pg_cryptohash_update(ctx, input, key_length) < 0 ||
     126          185 :         pg_cryptohash_final(ctx, result, key_length) < 0)
     127              :     {
     128            0 :         *errstr = pg_cryptohash_error(ctx);
     129            0 :         pg_cryptohash_free(ctx);
     130            0 :         return -1;
     131              :     }
     132              : 
     133          185 :     pg_cryptohash_free(ctx);
     134          185 :     return 0;
     135              : }
     136              : 
     137              : /*
     138              :  * Calculate ClientKey.  Returns 0 on success, -1 on failure with *errstr
     139              :  * pointing to a message about the error details.
     140              :  */
     141              : int
     142          116 : scram_ClientKey(const uint8 *salted_password,
     143              :                 pg_cryptohash_type hash_type, int key_length,
     144              :                 uint8 *result, const char **errstr)
     145              : {
     146          116 :     pg_hmac_ctx *ctx = pg_hmac_create(hash_type);
     147              : 
     148          116 :     if (ctx == NULL)
     149              :     {
     150            0 :         *errstr = pg_hmac_error(NULL);  /* returns OOM */
     151            0 :         return -1;
     152              :     }
     153              : 
     154          232 :     if (pg_hmac_init(ctx, salted_password, key_length) < 0 ||
     155          232 :         pg_hmac_update(ctx, (uint8 *) "Client Key", strlen("Client Key")) < 0 ||
     156          116 :         pg_hmac_final(ctx, result, key_length) < 0)
     157              :     {
     158            0 :         *errstr = pg_hmac_error(ctx);
     159            0 :         pg_hmac_free(ctx);
     160            0 :         return -1;
     161              :     }
     162              : 
     163          116 :     pg_hmac_free(ctx);
     164          116 :     return 0;
     165              : }
     166              : 
     167              : /*
     168              :  * Calculate ServerKey.  Returns 0 on success, -1 on failure with *errstr
     169              :  * pointing to a message about the error details.
     170              :  */
     171              : int
     172          136 : scram_ServerKey(const uint8 *salted_password,
     173              :                 pg_cryptohash_type hash_type, int key_length,
     174              :                 uint8 *result, const char **errstr)
     175              : {
     176          136 :     pg_hmac_ctx *ctx = pg_hmac_create(hash_type);
     177              : 
     178          136 :     if (ctx == NULL)
     179              :     {
     180            0 :         *errstr = pg_hmac_error(NULL);  /* returns OOM */
     181            0 :         return -1;
     182              :     }
     183              : 
     184          272 :     if (pg_hmac_init(ctx, salted_password, key_length) < 0 ||
     185          272 :         pg_hmac_update(ctx, (uint8 *) "Server Key", strlen("Server Key")) < 0 ||
     186          136 :         pg_hmac_final(ctx, result, key_length) < 0)
     187              :     {
     188            0 :         *errstr = pg_hmac_error(ctx);
     189            0 :         pg_hmac_free(ctx);
     190            0 :         return -1;
     191              :     }
     192              : 
     193          136 :     pg_hmac_free(ctx);
     194          136 :     return 0;
     195              : }
     196              : 
     197              : 
     198              : /*
     199              :  * Construct a SCRAM secret, for storing in pg_authid.rolpassword.
     200              :  *
     201              :  * The password should already have been processed with SASLprep, if necessary!
     202              :  *
     203              :  * The result is palloc'd or malloc'd, so caller is responsible for freeing it.
     204              :  *
     205              :  * On error, returns NULL and sets *errstr to point to a message about the
     206              :  * error details.
     207              :  */
     208              : char *
     209           58 : scram_build_secret(pg_cryptohash_type hash_type, int key_length,
     210              :                    const uint8 *salt, int saltlen, int iterations,
     211              :                    const char *password, const char **errstr)
     212              : {
     213              :     uint8       salted_password[SCRAM_MAX_KEY_LEN];
     214              :     uint8       stored_key[SCRAM_MAX_KEY_LEN];
     215              :     uint8       server_key[SCRAM_MAX_KEY_LEN];
     216              :     char       *result;
     217              :     char       *p;
     218              :     int         maxlen;
     219              :     int         encoded_salt_len;
     220              :     int         encoded_stored_len;
     221              :     int         encoded_server_len;
     222              :     int         encoded_result;
     223              : 
     224              :     /* Only this hash method is supported currently */
     225              :     Assert(hash_type == PG_SHA256);
     226              : 
     227              :     Assert(iterations > 0);
     228              : 
     229              :     /* Calculate StoredKey and ServerKey */
     230           58 :     if (scram_SaltedPassword(password, hash_type, key_length,
     231              :                              salt, saltlen, iterations,
     232           58 :                              salted_password, errstr) < 0 ||
     233           58 :         scram_ClientKey(salted_password, hash_type, key_length,
     234           58 :                         stored_key, errstr) < 0 ||
     235           58 :         scram_H(stored_key, hash_type, key_length,
     236           58 :                 stored_key, errstr) < 0 ||
     237           58 :         scram_ServerKey(salted_password, hash_type, key_length,
     238              :                         server_key, errstr) < 0)
     239              :     {
     240              :         /* errstr is filled already here */
     241              : #ifdef FRONTEND
     242            0 :         return NULL;
     243              : #else
     244            0 :         elog(ERROR, "could not calculate stored key and server key: %s",
     245              :              *errstr);
     246              : #endif
     247              :     }
     248              : 
     249              :     /*----------
     250              :      * The format is:
     251              :      * SCRAM-SHA-256$<iteration count>:<salt>$<StoredKey>:<ServerKey>
     252              :      *----------
     253              :      */
     254           58 :     encoded_salt_len = pg_b64_enc_len(saltlen);
     255           58 :     encoded_stored_len = pg_b64_enc_len(key_length);
     256           58 :     encoded_server_len = pg_b64_enc_len(key_length);
     257              : 
     258           58 :     maxlen = strlen("SCRAM-SHA-256") + 1
     259              :         + 10 + 1                /* iteration count */
     260              :         + encoded_salt_len + 1  /* Base64-encoded salt */
     261           58 :         + encoded_stored_len + 1    /* Base64-encoded StoredKey */
     262           58 :         + encoded_server_len + 1;   /* Base64-encoded ServerKey */
     263              : 
     264              : #ifdef FRONTEND
     265            1 :     result = malloc(maxlen);
     266            1 :     if (!result)
     267              :     {
     268            0 :         *errstr = _("out of memory");
     269            0 :         return NULL;
     270              :     }
     271              : #else
     272           57 :     result = palloc(maxlen);
     273              : #endif
     274              : 
     275           58 :     p = result + sprintf(result, "SCRAM-SHA-256$%d:", iterations);
     276              : 
     277              :     /* salt */
     278           58 :     encoded_result = pg_b64_encode(salt, saltlen, p, encoded_salt_len);
     279           58 :     if (encoded_result < 0)
     280              :     {
     281            0 :         *errstr = _("could not encode salt");
     282              : #ifdef FRONTEND
     283            0 :         free(result);
     284            0 :         return NULL;
     285              : #else
     286            0 :         elog(ERROR, "%s", *errstr);
     287              : #endif
     288              :     }
     289           58 :     p += encoded_result;
     290           58 :     *(p++) = '$';
     291              : 
     292              :     /* stored key */
     293           58 :     encoded_result = pg_b64_encode(stored_key, key_length, p,
     294              :                                    encoded_stored_len);
     295           58 :     if (encoded_result < 0)
     296              :     {
     297            0 :         *errstr = _("could not encode stored key");
     298              : #ifdef FRONTEND
     299            0 :         free(result);
     300            0 :         return NULL;
     301              : #else
     302            0 :         elog(ERROR, "%s", *errstr);
     303              : #endif
     304              :     }
     305              : 
     306           58 :     p += encoded_result;
     307           58 :     *(p++) = ':';
     308              : 
     309              :     /* server key */
     310           58 :     encoded_result = pg_b64_encode(server_key, key_length, p,
     311              :                                    encoded_server_len);
     312           58 :     if (encoded_result < 0)
     313              :     {
     314            0 :         *errstr = _("could not encode server key");
     315              : #ifdef FRONTEND
     316            0 :         free(result);
     317            0 :         return NULL;
     318              : #else
     319            0 :         elog(ERROR, "%s", *errstr);
     320              : #endif
     321              :     }
     322              : 
     323           58 :     p += encoded_result;
     324           58 :     *(p++) = '\0';
     325              : 
     326              :     Assert(p - result <= maxlen);
     327              : 
     328           58 :     return result;
     329              : }
        

Generated by: LCOV version 2.0-1