LCOV - code coverage report
Current view: top level - src/common - scram-common.c (source / functions) Hit Total Coverage
Test: PostgreSQL 18devel Lines: 75 114 65.8 %
Date: 2024-11-21 08:14:44 Functions: 5 5 100.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : /*-------------------------------------------------------------------------
       2             :  * scram-common.c
       3             :  *      Shared frontend/backend code for SCRAM authentication
       4             :  *
       5             :  * This contains the common low-level functions needed in both frontend and
       6             :  * backend, for implementing the Salted Challenge Response Authentication
       7             :  * Mechanism (SCRAM), per IETF's RFC 5802.
       8             :  *
       9             :  * Portions Copyright (c) 2017-2024, PostgreSQL Global Development Group
      10             :  *
      11             :  * IDENTIFICATION
      12             :  *    src/common/scram-common.c
      13             :  *
      14             :  *-------------------------------------------------------------------------
      15             :  */
      16             : #ifndef FRONTEND
      17             : #include "postgres.h"
      18             : #else
      19             : #include "postgres_fe.h"
      20             : #endif
      21             : 
      22             : #include "common/base64.h"
      23             : #include "common/hmac.h"
      24             : #include "common/scram-common.h"
      25             : #ifndef FRONTEND
      26             : #include "miscadmin.h"
      27             : #endif
      28             : #include "port/pg_bswap.h"
      29             : 
      30             : /*
      31             :  * Calculate SaltedPassword.
      32             :  *
      33             :  * The password should already be normalized by SASLprep.  Returns 0 on
      34             :  * success, -1 on failure with *errstr pointing to a message about the
      35             :  * error details.
      36             :  */
      37             : int
      38         228 : scram_SaltedPassword(const char *password,
      39             :                      pg_cryptohash_type hash_type, int key_length,
      40             :                      const char *salt, int saltlen, int iterations,
      41             :                      uint8 *result, const char **errstr)
      42             : {
      43         228 :     int         password_len = strlen(password);
      44         228 :     uint32      one = pg_hton32(1);
      45             :     int         i,
      46             :                 j;
      47             :     uint8       Ui[SCRAM_MAX_KEY_LEN];
      48             :     uint8       Ui_prev[SCRAM_MAX_KEY_LEN];
      49         228 :     pg_hmac_ctx *hmac_ctx = pg_hmac_create(hash_type);
      50             : 
      51         228 :     if (hmac_ctx == NULL)
      52             :     {
      53           0 :         *errstr = pg_hmac_error(NULL);  /* returns OOM */
      54           0 :         return -1;
      55             :     }
      56             : 
      57             :     /*
      58             :      * Iterate hash calculation of HMAC entry using given salt.  This is
      59             :      * essentially PBKDF2 (see RFC2898) with HMAC() as the pseudorandom
      60             :      * function.
      61             :      */
      62             : 
      63             :     /* First iteration */
      64         456 :     if (pg_hmac_init(hmac_ctx, (uint8 *) password, password_len) < 0 ||
      65         456 :         pg_hmac_update(hmac_ctx, (uint8 *) salt, saltlen) < 0 ||
      66         456 :         pg_hmac_update(hmac_ctx, (uint8 *) &one, sizeof(uint32)) < 0 ||
      67         228 :         pg_hmac_final(hmac_ctx, Ui_prev, key_length) < 0)
      68             :     {
      69           0 :         *errstr = pg_hmac_error(hmac_ctx);
      70           0 :         pg_hmac_free(hmac_ctx);
      71           0 :         return -1;
      72             :     }
      73             : 
      74         228 :     memcpy(result, Ui_prev, key_length);
      75             : 
      76             :     /* Subsequent iterations */
      77      848124 :     for (i = 2; i <= iterations; i++)
      78             :     {
      79             : #ifndef FRONTEND
      80             :         /*
      81             :          * Make sure that this is interruptible as scram_iterations could be
      82             :          * set to a large value.
      83             :          */
      84      528322 :         CHECK_FOR_INTERRUPTS();
      85             : #endif
      86             : 
      87     1695792 :         if (pg_hmac_init(hmac_ctx, (uint8 *) password, password_len) < 0 ||
      88     1695792 :             pg_hmac_update(hmac_ctx, (uint8 *) Ui_prev, key_length) < 0 ||
      89      847896 :             pg_hmac_final(hmac_ctx, Ui, key_length) < 0)
      90             :         {
      91           0 :             *errstr = pg_hmac_error(hmac_ctx);
      92           0 :             pg_hmac_free(hmac_ctx);
      93           0 :             return -1;
      94             :         }
      95             : 
      96    27980568 :         for (j = 0; j < key_length; j++)
      97    27132672 :             result[j] ^= Ui[j];
      98      847896 :         memcpy(Ui_prev, Ui, key_length);
      99             :     }
     100             : 
     101         228 :     pg_hmac_free(hmac_ctx);
     102         228 :     return 0;
     103             : }
     104             : 
     105             : 
     106             : /*
     107             :  * Calculate hash for a NULL-terminated string. (The NULL terminator is
     108             :  * not included in the hash).  Returns 0 on success, -1 on failure with *errstr
     109             :  * pointing to a message about the error details.
     110             :  */
     111             : int
     112         258 : scram_H(const uint8 *input, pg_cryptohash_type hash_type, int key_length,
     113             :         uint8 *result, const char **errstr)
     114             : {
     115             :     pg_cryptohash_ctx *ctx;
     116             : 
     117         258 :     ctx = pg_cryptohash_create(hash_type);
     118         258 :     if (ctx == NULL)
     119             :     {
     120           0 :         *errstr = pg_cryptohash_error(NULL);    /* returns OOM */
     121           0 :         return -1;
     122             :     }
     123             : 
     124         516 :     if (pg_cryptohash_init(ctx) < 0 ||
     125         516 :         pg_cryptohash_update(ctx, input, key_length) < 0 ||
     126         258 :         pg_cryptohash_final(ctx, result, key_length) < 0)
     127             :     {
     128           0 :         *errstr = pg_cryptohash_error(ctx);
     129           0 :         pg_cryptohash_free(ctx);
     130           0 :         return -1;
     131             :     }
     132             : 
     133         258 :     pg_cryptohash_free(ctx);
     134         258 :     return 0;
     135             : }
     136             : 
     137             : /*
     138             :  * Calculate ClientKey.  Returns 0 on success, -1 on failure with *errstr
     139             :  * pointing to a message about the error details.
     140             :  */
     141             : int
     142         178 : scram_ClientKey(const uint8 *salted_password,
     143             :                 pg_cryptohash_type hash_type, int key_length,
     144             :                 uint8 *result, const char **errstr)
     145             : {
     146         178 :     pg_hmac_ctx *ctx = pg_hmac_create(hash_type);
     147             : 
     148         178 :     if (ctx == NULL)
     149             :     {
     150           0 :         *errstr = pg_hmac_error(NULL);  /* returns OOM */
     151           0 :         return -1;
     152             :     }
     153             : 
     154         356 :     if (pg_hmac_init(ctx, salted_password, key_length) < 0 ||
     155         356 :         pg_hmac_update(ctx, (uint8 *) "Client Key", strlen("Client Key")) < 0 ||
     156         178 :         pg_hmac_final(ctx, result, key_length) < 0)
     157             :     {
     158           0 :         *errstr = pg_hmac_error(ctx);
     159           0 :         pg_hmac_free(ctx);
     160           0 :         return -1;
     161             :     }
     162             : 
     163         178 :     pg_hmac_free(ctx);
     164         178 :     return 0;
     165             : }
     166             : 
     167             : /*
     168             :  * Calculate ServerKey.  Returns 0 on success, -1 on failure with *errstr
     169             :  * pointing to a message about the error details.
     170             :  */
     171             : int
     172         216 : scram_ServerKey(const uint8 *salted_password,
     173             :                 pg_cryptohash_type hash_type, int key_length,
     174             :                 uint8 *result, const char **errstr)
     175             : {
     176         216 :     pg_hmac_ctx *ctx = pg_hmac_create(hash_type);
     177             : 
     178         216 :     if (ctx == NULL)
     179             :     {
     180           0 :         *errstr = pg_hmac_error(NULL);  /* returns OOM */
     181           0 :         return -1;
     182             :     }
     183             : 
     184         432 :     if (pg_hmac_init(ctx, salted_password, key_length) < 0 ||
     185         432 :         pg_hmac_update(ctx, (uint8 *) "Server Key", strlen("Server Key")) < 0 ||
     186         216 :         pg_hmac_final(ctx, result, key_length) < 0)
     187             :     {
     188           0 :         *errstr = pg_hmac_error(ctx);
     189           0 :         pg_hmac_free(ctx);
     190           0 :         return -1;
     191             :     }
     192             : 
     193         216 :     pg_hmac_free(ctx);
     194         216 :     return 0;
     195             : }
     196             : 
     197             : 
     198             : /*
     199             :  * Construct a SCRAM secret, for storing in pg_authid.rolpassword.
     200             :  *
     201             :  * The password should already have been processed with SASLprep, if necessary!
     202             :  *
     203             :  * If iterations is 0, default number of iterations is used.  The result is
     204             :  * palloc'd or malloc'd, so caller is responsible for freeing it.
     205             :  *
     206             :  * On error, returns NULL and sets *errstr to point to a message about the
     207             :  * error details.
     208             :  */
     209             : char *
     210          98 : scram_build_secret(pg_cryptohash_type hash_type, int key_length,
     211             :                    const char *salt, int saltlen, int iterations,
     212             :                    const char *password, const char **errstr)
     213             : {
     214             :     uint8       salted_password[SCRAM_MAX_KEY_LEN];
     215             :     uint8       stored_key[SCRAM_MAX_KEY_LEN];
     216             :     uint8       server_key[SCRAM_MAX_KEY_LEN];
     217             :     char       *result;
     218             :     char       *p;
     219             :     int         maxlen;
     220             :     int         encoded_salt_len;
     221             :     int         encoded_stored_len;
     222             :     int         encoded_server_len;
     223             :     int         encoded_result;
     224             : 
     225             :     /* Only this hash method is supported currently */
     226             :     Assert(hash_type == PG_SHA256);
     227             : 
     228             :     Assert(iterations > 0);
     229             : 
     230             :     /* Calculate StoredKey and ServerKey */
     231          98 :     if (scram_SaltedPassword(password, hash_type, key_length,
     232             :                              salt, saltlen, iterations,
     233          98 :                              salted_password, errstr) < 0 ||
     234          98 :         scram_ClientKey(salted_password, hash_type, key_length,
     235          98 :                         stored_key, errstr) < 0 ||
     236          98 :         scram_H(stored_key, hash_type, key_length,
     237          98 :                 stored_key, errstr) < 0 ||
     238          98 :         scram_ServerKey(salted_password, hash_type, key_length,
     239             :                         server_key, errstr) < 0)
     240             :     {
     241             :         /* errstr is filled already here */
     242             : #ifdef FRONTEND
     243           0 :         return NULL;
     244             : #else
     245           0 :         elog(ERROR, "could not calculate stored key and server key: %s",
     246             :              *errstr);
     247             : #endif
     248             :     }
     249             : 
     250             :     /*----------
     251             :      * The format is:
     252             :      * SCRAM-SHA-256$<iteration count>:<salt>$<StoredKey>:<ServerKey>
     253             :      *----------
     254             :      */
     255          98 :     encoded_salt_len = pg_b64_enc_len(saltlen);
     256          98 :     encoded_stored_len = pg_b64_enc_len(key_length);
     257          98 :     encoded_server_len = pg_b64_enc_len(key_length);
     258             : 
     259          98 :     maxlen = strlen("SCRAM-SHA-256") + 1
     260             :         + 10 + 1                /* iteration count */
     261             :         + encoded_salt_len + 1  /* Base64-encoded salt */
     262          98 :         + encoded_stored_len + 1    /* Base64-encoded StoredKey */
     263          98 :         + encoded_server_len + 1;   /* Base64-encoded ServerKey */
     264             : 
     265             : #ifdef FRONTEND
     266           2 :     result = malloc(maxlen);
     267           2 :     if (!result)
     268             :     {
     269           0 :         *errstr = _("out of memory");
     270           0 :         return NULL;
     271             :     }
     272             : #else
     273          96 :     result = palloc(maxlen);
     274             : #endif
     275             : 
     276          98 :     p = result + sprintf(result, "SCRAM-SHA-256$%d:", iterations);
     277             : 
     278             :     /* salt */
     279          98 :     encoded_result = pg_b64_encode(salt, saltlen, p, encoded_salt_len);
     280          98 :     if (encoded_result < 0)
     281             :     {
     282           0 :         *errstr = _("could not encode salt");
     283             : #ifdef FRONTEND
     284           0 :         free(result);
     285           0 :         return NULL;
     286             : #else
     287           0 :         elog(ERROR, "%s", *errstr);
     288             : #endif
     289             :     }
     290          98 :     p += encoded_result;
     291          98 :     *(p++) = '$';
     292             : 
     293             :     /* stored key */
     294          98 :     encoded_result = pg_b64_encode((char *) stored_key, key_length, p,
     295             :                                    encoded_stored_len);
     296          98 :     if (encoded_result < 0)
     297             :     {
     298           0 :         *errstr = _("could not encode stored key");
     299             : #ifdef FRONTEND
     300           0 :         free(result);
     301           0 :         return NULL;
     302             : #else
     303           0 :         elog(ERROR, "%s", *errstr);
     304             : #endif
     305             :     }
     306             : 
     307          98 :     p += encoded_result;
     308          98 :     *(p++) = ':';
     309             : 
     310             :     /* server key */
     311          98 :     encoded_result = pg_b64_encode((char *) server_key, key_length, p,
     312             :                                    encoded_server_len);
     313          98 :     if (encoded_result < 0)
     314             :     {
     315           0 :         *errstr = _("could not encode server key");
     316             : #ifdef FRONTEND
     317           0 :         free(result);
     318           0 :         return NULL;
     319             : #else
     320           0 :         elog(ERROR, "%s", *errstr);
     321             : #endif
     322             :     }
     323             : 
     324          98 :     p += encoded_result;
     325          98 :     *(p++) = '\0';
     326             : 
     327             :     Assert(p - result <= maxlen);
     328             : 
     329          98 :     return result;
     330             : }

Generated by: LCOV version 1.14