LCOV - code coverage report
Current view: top level - src/backend/libpq - auth-scram.c (source / functions) Hit Total Coverage
Test: PostgreSQL 18devel Lines: 311 398 78.1 %
Date: 2025-01-18 04:15:08 Functions: 17 19 89.5 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : /*-------------------------------------------------------------------------
       2             :  *
       3             :  * auth-scram.c
       4             :  *    Server-side implementation of the SASL SCRAM-SHA-256 mechanism.
       5             :  *
       6             :  * See the following RFCs for more details:
       7             :  * - RFC 5802: https://tools.ietf.org/html/rfc5802
       8             :  * - RFC 5803: https://tools.ietf.org/html/rfc5803
       9             :  * - RFC 7677: https://tools.ietf.org/html/rfc7677
      10             :  *
      11             :  * Here are some differences:
      12             :  *
      13             :  * - Username from the authentication exchange is not used. The client
      14             :  *   should send an empty string as the username.
      15             :  *
      16             :  * - If the password isn't valid UTF-8, or contains characters prohibited
      17             :  *   by the SASLprep profile, we skip the SASLprep pre-processing and use
      18             :  *   the raw bytes in calculating the hash.
      19             :  *
      20             :  * - If channel binding is used, the channel binding type is always
      21             :  *   "tls-server-end-point".  The spec says the default is "tls-unique"
      22             :  *   (RFC 5802, section 6.1. Default Channel Binding), but there are some
      23             :  *   problems with that.  Firstly, not all SSL libraries provide an API to
      24             :  *   get the TLS Finished message, required to use "tls-unique".  Secondly,
      25             :  *   "tls-unique" is not specified for TLS v1.3, and as of this writing,
      26             :  *   it's not clear if there will be a replacement.  We could support both
      27             :  *   "tls-server-end-point" and "tls-unique", but for our use case,
      28             :  *   "tls-unique" doesn't really have any advantages.  The main advantage
      29             :  *   of "tls-unique" would be that it works even if the server doesn't
      30             :  *   have a certificate, but PostgreSQL requires a server certificate
      31             :  *   whenever SSL is used, anyway.
      32             :  *
      33             :  *
      34             :  * The password stored in pg_authid consists of the iteration count, salt,
      35             :  * StoredKey and ServerKey.
      36             :  *
      37             :  * SASLprep usage
      38             :  * --------------
      39             :  *
      40             :  * One notable difference to the SCRAM specification is that while the
      41             :  * specification dictates that the password is in UTF-8, and prohibits
      42             :  * certain characters, we are more lenient.  If the password isn't a valid
      43             :  * UTF-8 string, or contains prohibited characters, the raw bytes are used
      44             :  * to calculate the hash instead, without SASLprep processing.  This is
      45             :  * because PostgreSQL supports other encodings too, and the encoding being
      46             :  * used during authentication is undefined (client_encoding isn't set until
      47             :  * after authentication).  In effect, we try to interpret the password as
      48             :  * UTF-8 and apply SASLprep processing, but if it looks invalid, we assume
      49             :  * that it's in some other encoding.
      50             :  *
      51             :  * In the worst case, we misinterpret a password that's in a different
      52             :  * encoding as being Unicode, because it happens to consist entirely of
      53             :  * valid UTF-8 bytes, and we apply Unicode normalization to it.  As long
      54             :  * as we do that consistently, that will not lead to failed logins.
      55             :  * Fortunately, the UTF-8 byte sequences that are ignored by SASLprep
      56             :  * don't correspond to any commonly used characters in any of the other
      57             :  * supported encodings, so it should not lead to any significant loss in
      58             :  * entropy, even if the normalization is incorrectly applied to a
      59             :  * non-UTF-8 password.
      60             :  *
      61             :  * Error handling
      62             :  * --------------
      63             :  *
      64             :  * Don't reveal user information to an unauthenticated client.  We don't
      65             :  * want an attacker to be able to probe whether a particular username is
      66             :  * valid.  In SCRAM, the server has to read the salt and iteration count
      67             :  * from the user's stored secret, and send it to the client.  To avoid
      68             :  * revealing whether a user exists, when the client tries to authenticate
      69             :  * with a username that doesn't exist, or doesn't have a valid SCRAM
      70             :  * secret in pg_authid, we create a fake salt and iteration count
      71             :  * on-the-fly, and proceed with the authentication with that.  In the end,
      72             :  * we'll reject the attempt, as if an incorrect password was given.  When
      73             :  * we are performing a "mock" authentication, the 'doomed' flag in
      74             :  * scram_state is set.
      75             :  *
      76             :  * In the error messages, avoid printing strings from the client, unless
      77             :  * you check that they are pure ASCII.  We don't want an unauthenticated
      78             :  * attacker to be able to spam the logs with characters that are not valid
      79             :  * to the encoding being used, whatever that is.  We cannot avoid that in
      80             :  * general, after logging in, but let's do what we can here.
      81             :  *
      82             :  *
      83             :  * Portions Copyright (c) 1996-2025, PostgreSQL Global Development Group
      84             :  * Portions Copyright (c) 1994, Regents of the University of California
      85             :  *
      86             :  * src/backend/libpq/auth-scram.c
      87             :  *
      88             :  *-------------------------------------------------------------------------
      89             :  */
      90             : #include "postgres.h"
      91             : 
      92             : #include <unistd.h>
      93             : 
      94             : #include "access/xlog.h"
      95             : #include "catalog/pg_control.h"
      96             : #include "common/base64.h"
      97             : #include "common/hmac.h"
      98             : #include "common/saslprep.h"
      99             : #include "common/scram-common.h"
     100             : #include "common/sha2.h"
     101             : #include "libpq/crypt.h"
     102             : #include "libpq/sasl.h"
     103             : #include "libpq/scram.h"
     104             : #include "miscadmin.h"
     105             : 
     106             : static void scram_get_mechanisms(Port *port, StringInfo buf);
     107             : static void *scram_init(Port *port, const char *selected_mech,
     108             :                         const char *shadow_pass);
     109             : static int  scram_exchange(void *opaq, const char *input, int inputlen,
     110             :                            char **output, int *outputlen,
     111             :                            const char **logdetail);
     112             : 
     113             : /* Mechanism declaration */
     114             : const pg_be_sasl_mech pg_be_scram_mech = {
     115             :     scram_get_mechanisms,
     116             :     scram_init,
     117             :     scram_exchange,
     118             : 
     119             :     PG_MAX_SASL_MESSAGE_LENGTH
     120             : };
     121             : 
     122             : /*
     123             :  * Status data for a SCRAM authentication exchange.  This should be kept
     124             :  * internal to this file.
     125             :  */
     126             : typedef enum
     127             : {
     128             :     SCRAM_AUTH_INIT,
     129             :     SCRAM_AUTH_SALT_SENT,
     130             :     SCRAM_AUTH_FINISHED,
     131             : } scram_state_enum;
     132             : 
     133             : typedef struct
     134             : {
     135             :     scram_state_enum state;
     136             : 
     137             :     const char *username;       /* username from startup packet */
     138             : 
     139             :     Port       *port;
     140             :     bool        channel_binding_in_use;
     141             : 
     142             :     /* State data depending on the hash type */
     143             :     pg_cryptohash_type hash_type;
     144             :     int         key_length;
     145             : 
     146             :     int         iterations;
     147             :     char       *salt;           /* base64-encoded */
     148             :     uint8       ClientKey[SCRAM_MAX_KEY_LEN];
     149             :     uint8       StoredKey[SCRAM_MAX_KEY_LEN];
     150             :     uint8       ServerKey[SCRAM_MAX_KEY_LEN];
     151             : 
     152             :     /* Fields of the first message from client */
     153             :     char        cbind_flag;
     154             :     char       *client_first_message_bare;
     155             :     char       *client_username;
     156             :     char       *client_nonce;
     157             : 
     158             :     /* Fields from the last message from client */
     159             :     char       *client_final_message_without_proof;
     160             :     char       *client_final_nonce;
     161             :     char        ClientProof[SCRAM_MAX_KEY_LEN];
     162             : 
     163             :     /* Fields generated in the server */
     164             :     char       *server_first_message;
     165             :     char       *server_nonce;
     166             : 
     167             :     /*
     168             :      * If something goes wrong during the authentication, or we are performing
     169             :      * a "mock" authentication (see comments at top of file), the 'doomed'
     170             :      * flag is set.  A reason for the failure, for the server log, is put in
     171             :      * 'logdetail'.
     172             :      */
     173             :     bool        doomed;
     174             :     char       *logdetail;
     175             : } scram_state;
     176             : 
     177             : static void read_client_first_message(scram_state *state, const char *input);
     178             : static void read_client_final_message(scram_state *state, const char *input);
     179             : static char *build_server_first_message(scram_state *state);
     180             : static char *build_server_final_message(scram_state *state);
     181             : static bool verify_client_proof(scram_state *state);
     182             : static bool verify_final_nonce(scram_state *state);
     183             : static void mock_scram_secret(const char *username, pg_cryptohash_type *hash_type,
     184             :                               int *iterations, int *key_length, char **salt,
     185             :                               uint8 *stored_key, uint8 *server_key);
     186             : static bool is_scram_printable(char *p);
     187             : static char *sanitize_char(char c);
     188             : static char *sanitize_str(const char *s);
     189             : static char *scram_mock_salt(const char *username,
     190             :                              pg_cryptohash_type hash_type,
     191             :                              int key_length);
     192             : 
     193             : /*
     194             :  * The number of iterations to use when generating new secrets.
     195             :  */
     196             : int         scram_sha_256_iterations = SCRAM_SHA_256_DEFAULT_ITERATIONS;
     197             : 
     198             : /*
     199             :  * Get a list of SASL mechanisms that this module supports.
     200             :  *
     201             :  * For the convenience of building the FE/BE packet that lists the
     202             :  * mechanisms, the names are appended to the given StringInfo buffer,
     203             :  * separated by '\0' bytes.
     204             :  */
     205             : static void
     206         118 : scram_get_mechanisms(Port *port, StringInfo buf)
     207             : {
     208             :     /*
     209             :      * Advertise the mechanisms in decreasing order of importance.  So the
     210             :      * channel-binding variants go first, if they are supported.  Channel
     211             :      * binding is only supported with SSL.
     212             :      */
     213             : #ifdef USE_SSL
     214         118 :     if (port->ssl_in_use)
     215             :     {
     216          14 :         appendStringInfoString(buf, SCRAM_SHA_256_PLUS_NAME);
     217          14 :         appendStringInfoChar(buf, '\0');
     218             :     }
     219             : #endif
     220         118 :     appendStringInfoString(buf, SCRAM_SHA_256_NAME);
     221         118 :     appendStringInfoChar(buf, '\0');
     222         118 : }
     223             : 
     224             : /*
     225             :  * Initialize a new SCRAM authentication exchange status tracker.  This
     226             :  * needs to be called before doing any exchange.  It will be filled in
     227             :  * with authentication information later, once the exchange has begun.
     228             :  *
     229             :  * 'selected_mech' identifies the SASL mechanism that the client selected.
     230             :  * It should be one of the mechanisms that we support, as returned by
     231             :  * scram_get_mechanisms().
     232             :  *
     233             :  * 'shadow_pass' is the role's stored secret, from pg_authid.rolpassword.
     234             :  * The username was provided by the client in the startup message, and is
     235             :  * available in port->user_name.  If 'shadow_pass' is NULL, we still perform
     236             :  * an authentication exchange, but it will fail, as if an incorrect password
     237             :  * was given.
     238             :  */
     239             : static void *
     240          96 : scram_init(Port *port, const char *selected_mech, const char *shadow_pass)
     241             : {
     242             :     scram_state *state;
     243             :     bool        got_secret;
     244             : 
     245          96 :     state = (scram_state *) palloc0(sizeof(scram_state));
     246          96 :     state->port = port;
     247          96 :     state->state = SCRAM_AUTH_INIT;
     248             : 
     249             :     /*
     250             :      * Parse the selected mechanism.
     251             :      *
     252             :      * Note that if we don't support channel binding, or if we're not using
     253             :      * SSL at all, we would not have advertised the PLUS variant in the first
     254             :      * place.  If the client nevertheless tries to select it, it's a protocol
     255             :      * violation like selecting any other SASL mechanism we don't support.
     256             :      */
     257             : #ifdef USE_SSL
     258          96 :     if (strcmp(selected_mech, SCRAM_SHA_256_PLUS_NAME) == 0 && port->ssl_in_use)
     259          10 :         state->channel_binding_in_use = true;
     260             :     else
     261             : #endif
     262          86 :     if (strcmp(selected_mech, SCRAM_SHA_256_NAME) == 0)
     263          86 :         state->channel_binding_in_use = false;
     264             :     else
     265           0 :         ereport(ERROR,
     266             :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
     267             :                  errmsg("client selected an invalid SASL authentication mechanism")));
     268             : 
     269             :     /*
     270             :      * Parse the stored secret.
     271             :      */
     272          96 :     if (shadow_pass)
     273             :     {
     274          96 :         int         password_type = get_password_type(shadow_pass);
     275             : 
     276          96 :         if (password_type == PASSWORD_TYPE_SCRAM_SHA_256)
     277             :         {
     278          94 :             if (parse_scram_secret(shadow_pass, &state->iterations,
     279             :                                    &state->hash_type, &state->key_length,
     280             :                                    &state->salt,
     281          94 :                                    state->StoredKey,
     282          94 :                                    state->ServerKey))
     283          94 :                 got_secret = true;
     284             :             else
     285             :             {
     286             :                 /*
     287             :                  * The password looked like a SCRAM secret, but could not be
     288             :                  * parsed.
     289             :                  */
     290           0 :                 ereport(LOG,
     291             :                         (errmsg("invalid SCRAM secret for user \"%s\"",
     292             :                                 state->port->user_name)));
     293           0 :                 got_secret = false;
     294             :             }
     295             :         }
     296             :         else
     297             :         {
     298             :             /*
     299             :              * The user doesn't have a SCRAM secret. (You cannot do SCRAM
     300             :              * authentication with an MD5 hash.)
     301             :              */
     302           4 :             state->logdetail = psprintf(_("User \"%s\" does not have a valid SCRAM secret."),
     303           2 :                                         state->port->user_name);
     304           2 :             got_secret = false;
     305             :         }
     306             :     }
     307             :     else
     308             :     {
     309             :         /*
     310             :          * The caller requested us to perform a dummy authentication.  This is
     311             :          * considered normal, since the caller requested it, so don't set log
     312             :          * detail.
     313             :          */
     314           0 :         got_secret = false;
     315             :     }
     316             : 
     317             :     /*
     318             :      * If the user did not have a valid SCRAM secret, we still go through the
     319             :      * motions with a mock one, and fail as if the client supplied an
     320             :      * incorrect password.  This is to avoid revealing information to an
     321             :      * attacker.
     322             :      */
     323          96 :     if (!got_secret)
     324             :     {
     325           2 :         mock_scram_secret(state->port->user_name, &state->hash_type,
     326             :                           &state->iterations, &state->key_length,
     327             :                           &state->salt,
     328           2 :                           state->StoredKey, state->ServerKey);
     329           2 :         state->doomed = true;
     330             :     }
     331             : 
     332          96 :     return state;
     333             : }
     334             : 
     335             : /*
     336             :  * Continue a SCRAM authentication exchange.
     337             :  *
     338             :  * 'input' is the SCRAM payload sent by the client.  On the first call,
     339             :  * 'input' contains the "Initial Client Response" that the client sent as
     340             :  * part of the SASLInitialResponse message, or NULL if no Initial Client
     341             :  * Response was given.  (The SASL specification distinguishes between an
     342             :  * empty response and non-existing one.)  On subsequent calls, 'input'
     343             :  * cannot be NULL.  For convenience in this function, the caller must
     344             :  * ensure that there is a null terminator at input[inputlen].
     345             :  *
     346             :  * The next message to send to client is saved in 'output', for a length
     347             :  * of 'outputlen'.  In the case of an error, optionally store a palloc'd
     348             :  * string at *logdetail that will be sent to the postmaster log (but not
     349             :  * the client).
     350             :  */
     351             : static int
     352         192 : scram_exchange(void *opaq, const char *input, int inputlen,
     353             :                char **output, int *outputlen, const char **logdetail)
     354             : {
     355         192 :     scram_state *state = (scram_state *) opaq;
     356             :     int         result;
     357             : 
     358         192 :     *output = NULL;
     359             : 
     360             :     /*
     361             :      * If the client didn't include an "Initial Client Response" in the
     362             :      * SASLInitialResponse message, send an empty challenge, to which the
     363             :      * client will respond with the same data that usually comes in the
     364             :      * Initial Client Response.
     365             :      */
     366         192 :     if (input == NULL)
     367             :     {
     368             :         Assert(state->state == SCRAM_AUTH_INIT);
     369             : 
     370           0 :         *output = pstrdup("");
     371           0 :         *outputlen = 0;
     372           0 :         return PG_SASL_EXCHANGE_CONTINUE;
     373             :     }
     374             : 
     375             :     /*
     376             :      * Check that the input length agrees with the string length of the input.
     377             :      * We can ignore inputlen after this.
     378             :      */
     379         192 :     if (inputlen == 0)
     380           0 :         ereport(ERROR,
     381             :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
     382             :                  errmsg("malformed SCRAM message"),
     383             :                  errdetail("The message is empty.")));
     384         192 :     if (inputlen != strlen(input))
     385           0 :         ereport(ERROR,
     386             :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
     387             :                  errmsg("malformed SCRAM message"),
     388             :                  errdetail("Message length does not match input length.")));
     389             : 
     390         192 :     switch (state->state)
     391             :     {
     392          96 :         case SCRAM_AUTH_INIT:
     393             : 
     394             :             /*
     395             :              * Initialization phase.  Receive the first message from client
     396             :              * and be sure that it parsed correctly.  Then send the challenge
     397             :              * to the client.
     398             :              */
     399          96 :             read_client_first_message(state, input);
     400             : 
     401             :             /* prepare message to send challenge */
     402          96 :             *output = build_server_first_message(state);
     403             : 
     404          96 :             state->state = SCRAM_AUTH_SALT_SENT;
     405          96 :             result = PG_SASL_EXCHANGE_CONTINUE;
     406          96 :             break;
     407             : 
     408          96 :         case SCRAM_AUTH_SALT_SENT:
     409             : 
     410             :             /*
     411             :              * Final phase for the server.  Receive the response to the
     412             :              * challenge previously sent, verify, and let the client know that
     413             :              * everything went well (or not).
     414             :              */
     415          96 :             read_client_final_message(state, input);
     416             : 
     417          96 :             if (!verify_final_nonce(state))
     418           0 :                 ereport(ERROR,
     419             :                         (errcode(ERRCODE_PROTOCOL_VIOLATION),
     420             :                          errmsg("invalid SCRAM response"),
     421             :                          errdetail("Nonce does not match.")));
     422             : 
     423             :             /*
     424             :              * Now check the client proof; the final nonce was verified above.
     425             :              *
     426             :              * If we performed a "mock" authentication that we knew would fail
     427             :              * from the get go, this is where we fail.
     428             :              *
     429             :              * The SCRAM specification includes an error code,
     430             :              * "invalid-proof", for authentication failure, but it also allows
     431             :              * erroring out in an application-specific way.  We choose to do
     432             :              * the latter, so that the error message for invalid password is
     433             :              * the same for all authentication methods.  The caller will call
     434             :              * ereport(), when we return PG_SASL_EXCHANGE_FAILURE with no
     435             :              * output.
     436             :              *
     437             :              * NB: the order of these checks is intentional.  We calculate the
     438             :              * client proof even in a mock authentication, even though it's
     439             :              * bound to fail, to thwart timing attacks to determine if a role
     440             :              * with the given name exists or not.
     441             :              */
     442          96 :             if (!verify_client_proof(state) || state->doomed)
     443             :             {
     444          10 :                 result = PG_SASL_EXCHANGE_FAILURE;
     445          10 :                 break;
     446             :             }
     447             : 
     448             :             /* Build final message for client */
     449          86 :             *output = build_server_final_message(state);
     450             : 
     451             :             /* Success! */
     452          86 :             result = PG_SASL_EXCHANGE_SUCCESS;
     453          86 :             state->state = SCRAM_AUTH_FINISHED;
     454          86 :             break;
     455             : 
     456           0 :         default:
     457           0 :             elog(ERROR, "invalid SCRAM exchange state");
     458             :             result = PG_SASL_EXCHANGE_FAILURE;
     459             :     }
     460             : 
     461         192 :     if (result == PG_SASL_EXCHANGE_FAILURE && state->logdetail && logdetail)
     462           2 :         *logdetail = state->logdetail;
     463             : 
     464         192 :     if (*output)
     465         182 :         *outputlen = strlen(*output);
     466             : 
     467         192 :     if (result == PG_SASL_EXCHANGE_SUCCESS && state->state == SCRAM_AUTH_FINISHED)
     468             :     {
     469          86 :         memcpy(MyProcPort->scram_ClientKey, state->ClientKey, sizeof(MyProcPort->scram_ClientKey));
     470          86 :         memcpy(MyProcPort->scram_ServerKey, state->ServerKey, sizeof(MyProcPort->scram_ServerKey));
     471          86 :         MyProcPort->has_scram_keys = true;
     472             :     }
     473             : 
     474         192 :     return result;
     475             : }
     476             : 
     477             : /*
     478             :  * Construct a SCRAM secret, for storing in pg_authid.rolpassword.
     479             :  *
     480             :  * The result is palloc'd, so caller is responsible for freeing it.
     481             :  */
     482             : char *
     483         102 : pg_be_scram_build_secret(const char *password)
     484             : {
     485             :     char       *prep_password;
     486             :     pg_saslprep_rc rc;
     487             :     char        saltbuf[SCRAM_DEFAULT_SALT_LEN];
     488             :     char       *result;
     489         102 :     const char *errstr = NULL;
     490             : 
     491             :     /*
     492             :      * Normalize the password with SASLprep.  If that doesn't work, because
     493             :      * the password isn't valid UTF-8 or contains prohibited characters, just
     494             :      * proceed with the original password.  (See comments at top of file.)
     495             :      */
     496         102 :     rc = pg_saslprep(password, &prep_password);
     497         102 :     if (rc == SASLPREP_SUCCESS)
     498         100 :         password = (const char *) prep_password;
     499             : 
     500             :     /* Generate random salt */
     501         102 :     if (!pg_strong_random(saltbuf, SCRAM_DEFAULT_SALT_LEN))
     502           0 :         ereport(ERROR,
     503             :                 (errcode(ERRCODE_INTERNAL_ERROR),
     504             :                  errmsg("could not generate random salt")));
     505             : 
     506         102 :     result = scram_build_secret(PG_SHA256, SCRAM_SHA_256_KEY_LEN,
     507             :                                 saltbuf, SCRAM_DEFAULT_SALT_LEN,
     508             :                                 scram_sha_256_iterations, password,
     509             :                                 &errstr);
     510             : 
     511         102 :     if (prep_password)
     512         100 :         pfree(prep_password);
     513             : 
     514         102 :     return result;
     515             : }
     516             : 
     517             : /*
     518             :  * Verify a plaintext password against a SCRAM secret.  This is used when
     519             :  * performing plaintext password authentication for a user that has a SCRAM
     520             :  * secret stored in pg_authid.
                     :  *
                     :  * 'username' is only used in log messages.  Returns true when the Server
                     :  * Key derived from 'password' (using the salt and iteration count found in
                     :  * 'secret') equals the Server Key stored in 'secret'.  Returns false, after
                     :  * writing a LOG message, if 'secret' cannot be parsed or its salt is not
                     :  * valid base64.  A failure while deriving the key is raised as an ERROR.
     521             :  */
     522             : bool
     523          52 : scram_verify_plain_password(const char *username, const char *password,
     524             :                             const char *secret)
     525             : {
     526             :     char       *encoded_salt;
     527             :     char       *salt;
     528             :     int         saltlen;
     529             :     int         iterations;
     530          52 :     int         key_length = 0;
     531             :     pg_cryptohash_type hash_type;
     532             :     uint8       salted_password[SCRAM_MAX_KEY_LEN];
     533             :     uint8       stored_key[SCRAM_MAX_KEY_LEN];
     534             :     uint8       server_key[SCRAM_MAX_KEY_LEN];
     535             :     uint8       computed_key[SCRAM_MAX_KEY_LEN];
     536             :     char       *prep_password;
     537             :     pg_saslprep_rc rc;
     538          52 :     const char *errstr = NULL;
     539             : 
     540          52 :     if (!parse_scram_secret(secret, &iterations, &hash_type, &key_length,
     541             :                             &encoded_salt, stored_key, server_key))
     542             :     {
     543             :         /*
     544             :          * The password looked like a SCRAM secret, but could not be parsed.
     545             :          */
     546           0 :         ereport(LOG,
     547             :                 (errmsg("invalid SCRAM secret for user \"%s\"", username)));
     548           0 :         return false;
     549             :     }
     550             : 
     551          52 :     saltlen = pg_b64_dec_len(strlen(encoded_salt));
     552          52 :     salt = palloc(saltlen);
     553          52 :     saltlen = pg_b64_decode(encoded_salt, strlen(encoded_salt), salt,
     554             :                             saltlen);
     555          52 :     if (saltlen < 0)
     556             :     {
     557           0 :         ereport(LOG,
     558             :                 (errmsg("invalid SCRAM secret for user \"%s\"", username)));
     559           0 :         return false;
     560             :     }
     561             : 
     562             :     /*
                     :      * Normalize the password.  If SASLprep does not succeed, the password
                     :      * is used exactly as supplied.
                     :      *
                     :      * NOTE(review): prep_password has no initializer and is tested further
                     :      * down, so this relies on pg_saslprep() always storing a value
                     :      * (presumably NULL on failure) -- confirm.
                     :      */
     563          52 :     rc = pg_saslprep(password, &prep_password);
     564          52 :     if (rc == SASLPREP_SUCCESS)
     565          52 :         password = prep_password;
     566             : 
     567             :     /* Compute Server Key based on the user-supplied plaintext password */
     568          52 :     if (scram_SaltedPassword(password, hash_type, key_length,
     569             :                              salt, saltlen, iterations,
     570          52 :                              salted_password, &errstr) < 0 ||
     571          52 :         scram_ServerKey(salted_password, hash_type, key_length,
     572             :                         computed_key, &errstr) < 0)
     573             :     {
     574           0 :         elog(ERROR, "could not compute server key: %s", errstr);
     575             :     }
     576             : 
     577          52 :     if (prep_password)
     578          52 :         pfree(prep_password);
     579             : 
     580             :     /*
     581             :      * Compare the secret's Server Key with the one computed from the
     582             :      * user-supplied password.
     583             :      */
     584          52 :     return memcmp(computed_key, server_key, key_length) == 0;
     585             : }
     586             : 
     587             : 
     588             : /*
     589             :  * Parse and validate format of given SCRAM secret.
     590             :  *
     591             :  * On success, the iteration count, salt, stored key, and server key are
     592             :  * extracted from the secret, and returned to the caller.  For 'stored_key'
     593             :  * and 'server_key', the caller must pass pre-allocated buffers of size
     594             :  * SCRAM_MAX_KEY_LEN.  Salt is returned as a base64-encoded, null-terminated
     595             :  * string.  The buffer for the salt is palloc'd by this function.
     596             :  *
     597             :  * Returns true if the SCRAM secret has been parsed, and false otherwise.
                     :  *
                     :  * On failure, *salt is set to NULL.  The other output arguments may already
                     :  * have been written by the time a later field is found to be invalid, so
                     :  * their contents are meaningful only when true is returned.
     598             :  */
     599             : bool
     600         808 : parse_scram_secret(const char *secret, int *iterations,
     601             :                    pg_cryptohash_type *hash_type, int *key_length,
     602             :                    char **salt, uint8 *stored_key, uint8 *server_key)
     603             : {
     604             :     char       *v;
     605             :     char       *p;
     606             :     char       *scheme_str;
     607             :     char       *salt_str;
     608             :     char       *iterations_str;
     609             :     char       *storedkey_str;
     610             :     char       *serverkey_str;
     611             :     int         decoded_len;
     612             :     char       *decoded_salt_buf;
     613             :     char       *decoded_stored_buf;
     614             :     char       *decoded_server_buf;
     615             : 
     616             :     /*
     617             :      * The secret is of form:
     618             :      *
     619             :      * SCRAM-SHA-256$<iterations>:<salt>$<storedkey>:<serverkey>
                     :      *
                     :      * The fields are split in place on a private copy of the secret.  After
                     :      * each strsep() call, a NULL 'v' means the expected delimiter was not
                     :      * found, i.e. the secret is malformed.
     620             :      */
     621         808 :     v = pstrdup(secret);
     622         808 :     scheme_str = strsep(&v, "$");
     623         808 :     if (v == NULL)
     624         222 :         goto invalid_secret;
     625         586 :     iterations_str = strsep(&v, ":");
     626         586 :     if (v == NULL)
     627          12 :         goto invalid_secret;
     628         574 :     salt_str = strsep(&v, "$");
     629         574 :     if (v == NULL)
     630           0 :         goto invalid_secret;
     631         574 :     storedkey_str = strsep(&v, ":");
     632         574 :     if (v == NULL)
     633           0 :         goto invalid_secret;
     634         574 :     serverkey_str = v;
     635             : 
     636             :     /*
                     :      * Parse the fields.  Only the "SCRAM-SHA-256" scheme is accepted.
                     :      *
                     :      * NOTE(review): the iteration count check below only rejects trailing
                     :      * garbage and a nonzero errno.  An empty field, a negative number, or a
                     :      * value that fits in a long but not in an int appear to be accepted --
                     :      * confirm whether stricter validation is wanted.
                     :      */
     637         574 :     if (strcmp(scheme_str, "SCRAM-SHA-256") != 0)
     638           0 :         goto invalid_secret;
     639         574 :     *hash_type = PG_SHA256;
     640         574 :     *key_length = SCRAM_SHA_256_KEY_LEN;
     641             : 
     642         574 :     errno = 0;
     643         574 :     *iterations = strtol(iterations_str, &p, 10);
     644         574 :     if (*p || errno != 0)
     645           0 :         goto invalid_secret;
     646             : 
     647             :     /*
     648             :      * Verify that the salt is in Base64-encoded format, by decoding it,
     649             :      * although we return the encoded version to the caller.
     650             :      */
     651         574 :     decoded_len = pg_b64_dec_len(strlen(salt_str));
     652         574 :     decoded_salt_buf = palloc(decoded_len);
     653         574 :     decoded_len = pg_b64_decode(salt_str, strlen(salt_str),
     654             :                                 decoded_salt_buf, decoded_len);
     655         574 :     if (decoded_len < 0)
     656           0 :         goto invalid_secret;
     657         574 :     *salt = pstrdup(salt_str);
     658             : 
     659             :     /*
     660             :      * Decode StoredKey and ServerKey.
                     :      *
                     :      * Each key must decode to exactly *key_length bytes.  A decoding error
                     :      * yields a negative length and is therefore rejected by the same test.
     661             :      */
     662         574 :     decoded_len = pg_b64_dec_len(strlen(storedkey_str));
     663         574 :     decoded_stored_buf = palloc(decoded_len);
     664         574 :     decoded_len = pg_b64_decode(storedkey_str, strlen(storedkey_str),
     665             :                                 decoded_stored_buf, decoded_len);
     666         574 :     if (decoded_len != *key_length)
     667          12 :         goto invalid_secret;
     668         562 :     memcpy(stored_key, decoded_stored_buf, *key_length);
     669             : 
     670         562 :     decoded_len = pg_b64_dec_len(strlen(serverkey_str));
     671         562 :     decoded_server_buf = palloc(decoded_len);
     672         562 :     decoded_len = pg_b64_decode(serverkey_str, strlen(serverkey_str),
     673             :                                 decoded_server_buf, decoded_len);
     674         562 :     if (decoded_len != *key_length)
     675          12 :         goto invalid_secret;
     676         550 :     memcpy(server_key, decoded_server_buf, *key_length);
     677             : 
     678         550 :     return true;
     679             : 
     680         258 : invalid_secret:
     681         258 :     *salt = NULL;
     682         258 :     return false;
     683             : }
     684             : 
     685             : /*
     686             :  * Generate plausible SCRAM secret parameters for mock authentication.
     687             :  *
     688             :  * In a normal authentication, these are extracted from the secret
     689             :  * stored in the server.  This function generates values that look
     690             :  * realistic, for when there is no stored secret, using SCRAM-SHA-256.
     691             :  *
     692             :  * Like in parse_scram_secret(), for 'stored_key' and 'server_key', the
     693             :  * caller must pass pre-allocated buffers of size SCRAM_MAX_KEY_LEN, and
     694             :  * the buffer for the salt is palloc'd by this function.
                     :  *
                     :  * On return: *hash_type is PG_SHA256, *key_length is SCRAM_SHA_256_KEY_LEN,
                     :  * *iterations is SCRAM_SHA_256_DEFAULT_ITERATIONS, *salt is the
                     :  * base64-encoded, null-terminated mock salt derived from 'username', and
                     :  * both key buffers are zero-filled.  Failures are raised with elog(ERROR).
     695             :  */
     696             : static void
     697           2 : mock_scram_secret(const char *username, pg_cryptohash_type *hash_type,
     698             :                   int *iterations, int *key_length, char **salt,
     699             :                   uint8 *stored_key, uint8 *server_key)
     700             : {
     701             :     char       *raw_salt;
     702             :     char       *encoded_salt;
     703             :     int         encoded_len;
     704             : 
     705             :     /* Enforce the use of SHA-256, which would be realistic enough */
     706           2 :     *hash_type = PG_SHA256;
     707           2 :     *key_length = SCRAM_SHA_256_KEY_LEN;
     708             : 
     709             :     /*
     710             :      * Generate deterministic salt.
     711             :      *
     712             :      * Note that we cannot reveal any information to an attacker here so the
     713             :      * error messages need to remain generic.  This should never fail anyway
     714             :      * as the salt generated for mock authentication uses the cluster's nonce
     715             :      * value.
                     :      *
                     :      * That is why a failure to generate the salt is reported with the same
                     :      * message as a failure to encode it.
     716             :      */
     717           2 :     raw_salt = scram_mock_salt(username, *hash_type, *key_length);
     718           2 :     if (raw_salt == NULL)
     719           0 :         elog(ERROR, "could not encode salt");
     720             : 
     721           2 :     encoded_len = pg_b64_enc_len(SCRAM_DEFAULT_SALT_LEN);
     722             :     /* don't forget the zero-terminator */
     723           2 :     encoded_salt = (char *) palloc(encoded_len + 1);
     724           2 :     encoded_len = pg_b64_encode(raw_salt, SCRAM_DEFAULT_SALT_LEN, encoded_salt,
     725             :                                 encoded_len);
     726             : 
     727           2 :     if (encoded_len < 0)
     728           0 :         elog(ERROR, "could not encode salt");
     729           2 :     encoded_salt[encoded_len] = '\0';
     730             : 
     731           2 :     *salt = encoded_salt;
     732           2 :     *iterations = SCRAM_SHA_256_DEFAULT_ITERATIONS;
     733             : 
     734             :     /* StoredKey and ServerKey are not used in a doomed authentication */
     735           2 :     memset(stored_key, 0, SCRAM_MAX_KEY_LEN);
     736           2 :     memset(server_key, 0, SCRAM_MAX_KEY_LEN);
     737           2 : }
     738             : 
     739             : /*
     740             :  * Read the value in a given SCRAM exchange message for given attribute.
                     :  *
                     :  * *input must start with "<attr>=".  The value extends up to the next comma
                     :  * or the end of the string.  The input buffer is modified in place: a
                     :  * terminating comma is overwritten with a null byte, and the returned
                     :  * pointer refers to the value inside that same buffer.  *input is advanced
                     :  * past the comma, or left pointing at the terminating null byte if the
                     :  * value was the last item.
                     :  *
                     :  * A missing attribute letter or "=" is reported with ereport(ERROR) as a
                     :  * protocol violation.
     741             :  */
     742             : static char *
     743         394 : read_attr_value(char **input, char attr)
     744             : {
     745         394 :     char       *begin = *input;
     746             :     char       *end;
     747             : 
     748         394 :     if (*begin != attr)
     749           0 :         ereport(ERROR,
     750             :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
     751             :                  errmsg("malformed SCRAM message"),
     752             :                  errdetail("Expected attribute \"%c\" but found \"%s\".",
     753             :                            attr, sanitize_char(*begin))));
     754         394 :     begin++;
     755             : 
     756         394 :     if (*begin != '=')
     757           0 :         ereport(ERROR,
     758             :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
     759             :                  errmsg("malformed SCRAM message"),
     760             :                  errdetail("Expected character \"=\" for attribute \"%c\".", attr)));
     761         394 :     begin++;
     762             : 
     763         394 :     end = begin;
     764        8610 :     while (*end && *end != ',')
     765        8216 :         end++;
     766             : 
     767         394 :     if (*end)
     768             :     {
     769         298 :         *end = '\0';
     770         298 :         *input = end + 1;
     771             :     }
     772             :     else
     773          96 :         *input = end;
     774             : 
     775         394 :     return begin;
     776             : }
     777             : 
     778             : static bool
     779          96 : is_scram_printable(char *p)
     780             : {
     781             :     /*------
                     :      * Returns true if every byte of the null-terminated string 'p' is a
                     :      * "printable" character in the sense below, and false as soon as one is
                     :      * not.  An empty string yields true.
                     :      *
     782             :      * Printable characters, as defined by SCRAM spec: (RFC 5802)
     783             :      *
     784             :      *  printable       = %x21-2B / %x2D-7E
     785             :      *                    ;; Printable ASCII except ",".
     786             :      *                    ;; Note that any "printable" is also
     787             :      *                    ;; a valid "value".
     788             :      *------
     789             :      */
     790        2400 :     for (; *p; p++)
     791             :     {
     792        2304 :         if (*p < 0x21 || *p > 0x7E || *p == 0x2C /* comma */ )
     793           0 :             return false;
     794             :     }
     795          96 :     return true;
     796             : }
     797             : 
     798             : /*
     799             :  * Convert an arbitrary byte to printable form.  For error messages.
     800             :  *
     801             :  * If it's a printable ASCII character (0x21 to 0x7E), print it as a single
     802             :  * character wrapped in single quotes.  Otherwise, print it in hex as 0xNN;
                     :  * this includes the space character.
     803             :  *
     804             :  * The returned pointer points to a static buffer.
                     :  * The buffer is overwritten by every call, so the result must be consumed
                     :  * before this function is called again.
     805             :  */
     806             : static char *
     807           0 : sanitize_char(char c)
     808             : {
     809             :     static char buf[5];
     810             : 
     811           0 :     if (c >= 0x21 && c <= 0x7E)
     812           0 :         snprintf(buf, sizeof(buf), "'%c'", c);
     813             :     else
     814           0 :         snprintf(buf, sizeof(buf), "0x%02x", (unsigned char) c);
     815           0 :     return buf;
     816             : }
     817             : 
     818             : /*
     819             :  * Convert an arbitrary string to printable form, for error messages.
     820             :  *
     821             :  * Anything that's not a printable ASCII character is replaced with
     822             :  * '?', and the string is truncated at 30 characters.
                     :  * Printable here means 0x21 to 0x7E, so a space is replaced as well.
     823             :  *
     824             :  * The returned pointer points to a static buffer.
                     :  * The buffer is overwritten by every call, so the result must be consumed
                     :  * before this function is called again.
     825             :  */
     826             : static char *
     827           0 : sanitize_str(const char *s)
     828             : {
     829             :     static char buf[30 + 1];
     830             :     int         i;
     831             : 
     832           0 :     for (i = 0; i < sizeof(buf) - 1; i++)
     833             :     {
     834           0 :         char        c = s[i];
     835             : 
     836           0 :         if (c == '\0')
     837           0 :             break;
     838             : 
     839           0 :         if (c >= 0x21 && c <= 0x7E)
     840           0 :             buf[i] = c;
     841             :         else
     842           0 :             buf[i] = '?';
     843             :     }
     844           0 :     buf[i] = '\0';
     845           0 :     return buf;
     846             : }
     847             : 
     848             : /*
     849             :  * Read the next attribute and value in a SCRAM exchange message.
     850             :  *
     851             :  * The attribute character is set in *attr_p, the attribute value is the
     852             :  * return value.
                     :  *
                     :  * 'attr_p' may be NULL if the caller is not interested in the attribute
                     :  * character.  The input buffer is modified in place: a comma that ends the
                     :  * value is overwritten with a null byte, and the returned pointer refers to
                     :  * the value inside that same buffer.  *input is advanced past the comma, or
                     :  * left pointing at the terminating null byte if this was the last item.
                     :  *
                     :  * Malformed input is reported with ereport(ERROR) as a protocol violation.
     853             :  */
     854             : static char *
     855          96 : read_any_attr(char **input, char *attr_p)
     856             : {
     857          96 :     char       *begin = *input;
     858             :     char       *end;
     859          96 :     char        attr = *begin;
     860             : 
     861          96 :     if (attr == '\0')
     862           0 :         ereport(ERROR,
     863             :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
     864             :                  errmsg("malformed SCRAM message"),
     865             :                  errdetail("Attribute expected, but found end of string.")));
     866             : 
     867             :     /*------
     868             :      * attr-val        = ALPHA "=" value
     869             :      *                   ;; Generic syntax of any attribute sent
     870             :      *                   ;; by server or client
     871             :      *------
     872             :      */
     873          96 :     if (!((attr >= 'A' && attr <= 'Z') ||
     874          96 :           (attr >= 'a' && attr <= 'z')))
     875           0 :         ereport(ERROR,
     876             :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
     877             :                  errmsg("malformed SCRAM message"),
     878             :                  errdetail("Attribute expected, but found invalid character \"%s\".",
     879             :                            sanitize_char(attr))));
     880          96 :     if (attr_p)
     881          96 :         *attr_p = attr;
     882          96 :     begin++;
     883             : 
     884          96 :     if (*begin != '=')
     885           0 :         ereport(ERROR,
     886             :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
     887             :                  errmsg("malformed SCRAM message"),
     888             :                  errdetail("Expected character \"=\" for attribute \"%c\".", attr)));
     889          96 :     begin++;
     890             : 
     891          96 :     end = begin;
     892        4320 :     while (*end && *end != ',')
     893        4224 :         end++;
     894             : 
     895          96 :     if (*end)
     896             :     {
     897           0 :         *end = '\0';
     898           0 :         *input = end + 1;
     899             :     }
     900             :     else
     901          96 :         *input = end;
     902             : 
     903          96 :     return begin;
     904             : }
     905             : 
     906             : /*
     907             :  * Read and parse the first message from client in the context of a SCRAM
     908             :  * authentication exchange message.
     909             :  *
     910             :  * At this stage, any errors will be reported directly with ereport(ERROR).
                     :  *
                     :  * 'input' itself is not modified; parsing works on a pstrdup'd copy.  On
                     :  * return, state->cbind_flag, state->client_first_message_bare,
                     :  * state->client_username and state->client_nonce have been filled in.  The
                     :  * username and nonce point into that working copy, while
                     :  * client_first_message_bare is a separate copy taken before the attributes
                     :  * are split in place.
     911             :  */
     912             : static void
     913          96 : read_client_first_message(scram_state *state, const char *input)
     914             : {
     915          96 :     char       *p = pstrdup(input);
     916             :     char       *channel_binding_type;
     917             : 
     918             : 
     919             :     /*------
     920             :      * The syntax for the client-first-message is: (RFC 5802)
     921             :      *
     922             :      * saslname        = 1*(value-safe-char / "=2C" / "=3D")
     923             :      *                   ;; Conforms to <value>.
     924             :      *
     925             :      * authzid         = "a=" saslname
     926             :      *                   ;; Protocol specific.
     927             :      *
     928             :      * cb-name         = 1*(ALPHA / DIGIT / "." / "-")
     929             :      *                    ;; See RFC 5056, Section 7.
     930             :      *                    ;; E.g., "tls-server-end-point" or
     931             :      *                    ;; "tls-unique".
     932             :      *
     933             :      * gs2-cbind-flag  = ("p=" cb-name) / "n" / "y"
     934             :      *                   ;; "n" -> client doesn't support channel binding.
     935             :      *                   ;; "y" -> client does support channel binding
     936             :      *                   ;;        but thinks the server does not.
     937             :      *                   ;; "p" -> client requires channel binding.
     938             :      *                   ;; The selected channel binding follows "p=".
     939             :      *
     940             :      * gs2-header      = gs2-cbind-flag "," [ authzid ] ","
     941             :      *                   ;; GS2 header for SCRAM
     942             :      *                   ;; (the actual GS2 header includes an optional
     943             :      *                   ;; flag to indicate that the GSS mechanism is not
     944             :      *                   ;; "standard", but since SCRAM is "standard", we
     945             :      *                   ;; don't include that flag).
     946             :      *
     947             :      * username        = "n=" saslname
     948             :      *                   ;; Usernames are prepared using SASLprep.
     949             :      *
     950             :      * reserved-mext  = "m=" 1*(value-char)
     951             :      *                   ;; Reserved for signaling mandatory extensions.
     952             :      *                   ;; The exact syntax will be defined in
     953             :      *                   ;; the future.
     954             :      *
     955             :      * nonce           = "r=" c-nonce [s-nonce]
     956             :      *                   ;; Second part provided by server.
     957             :      *
     958             :      * c-nonce         = printable
     959             :      *
     960             :      * client-first-message-bare =
     961             :      *                   [reserved-mext ","]
     962             :      *                   username "," nonce ["," extensions]
     963             :      *
     964             :      * client-first-message =
     965             :      *                   gs2-header client-first-message-bare
     966             :      *
     967             :      * For example:
     968             :      * n,,n=user,r=fyko+d2lbbFgONRv9qkxdawL
     969             :      *
     970             :      * The "n,," in the beginning means that the client doesn't support
     971             :      * channel binding, and no authzid is given.  "n=user" is the username.
     972             :      * However, in PostgreSQL the username is sent in the startup packet, and
     973             :      * the username in the SCRAM exchange is ignored.  libpq always sends it
     974             :      * as an empty string.  The last part, "r=fyko+d2lbbFgONRv9qkxdawL" is
     975             :      * the client nonce.
     976             :      *------
     977             :      */
     978             : 
     979             :     /*
     980             :      * Read gs2-cbind-flag.  (For details see also RFC 5802 Section 6 "Channel
     981             :      * Binding".)
                     :      *
                     :      * Every branch below leaves 'p' just after the comma that follows the
                     :      * flag: the 'n' and 'y' branches skip it explicitly, and in the 'p'
                     :      * branch read_attr_value() consumes it.
     982             :      */
     983          96 :     state->cbind_flag = *p;
     984          96 :     switch (*p)
     985             :     {
     986          86 :         case 'n':
     987             : 
     988             :             /*
     989             :              * The client does not support channel binding or has simply
     990             :              * decided to not use it.  In that case just let it go.
                     :              * It is an error, however, if the channel-binding variant of
                     :              * the mechanism is in use for this exchange.
     991             :              */
     992          86 :             if (state->channel_binding_in_use)
     993           0 :                 ereport(ERROR,
     994             :                         (errcode(ERRCODE_PROTOCOL_VIOLATION),
     995             :                          errmsg("malformed SCRAM message"),
     996             :                          errdetail("The client selected SCRAM-SHA-256-PLUS, but the SCRAM message does not include channel binding data.")));
     997             : 
     998          86 :             p++;
     999          86 :             if (*p != ',')
    1000           0 :                 ereport(ERROR,
    1001             :                         (errcode(ERRCODE_PROTOCOL_VIOLATION),
    1002             :                          errmsg("malformed SCRAM message"),
    1003             :                          errdetail("Comma expected, but found character \"%s\".",
    1004             :                                    sanitize_char(*p))));
    1005          86 :             p++;
    1006          86 :             break;
    1007           0 :         case 'y':
    1008             : 
    1009             :             /*
    1010             :              * The client supports channel binding and thinks that the server
    1011             :              * does not.  In this case, the server must fail authentication if
    1012             :              * it supports channel binding.
                     :              * In this code, that check is made only in builds with USE_SSL
                     :              * and only when the connection actually uses SSL.
    1013             :              */
    1014           0 :             if (state->channel_binding_in_use)
    1015           0 :                 ereport(ERROR,
    1016             :                         (errcode(ERRCODE_PROTOCOL_VIOLATION),
    1017             :                          errmsg("malformed SCRAM message"),
    1018             :                          errdetail("The client selected SCRAM-SHA-256-PLUS, but the SCRAM message does not include channel binding data.")));
    1019             : 
    1020             : #ifdef USE_SSL
    1021           0 :             if (state->port->ssl_in_use)
    1022           0 :                 ereport(ERROR,
    1023             :                         (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
    1024             :                          errmsg("SCRAM channel binding negotiation error"),
    1025             :                          errdetail("The client supports SCRAM channel binding but thinks the server does not.  "
    1026             :                                    "However, this server does support channel binding.")));
    1027             : #endif
    1028           0 :             p++;
    1029           0 :             if (*p != ',')
    1030           0 :                 ereport(ERROR,
    1031             :                         (errcode(ERRCODE_PROTOCOL_VIOLATION),
    1032             :                          errmsg("malformed SCRAM message"),
    1033             :                          errdetail("Comma expected, but found character \"%s\".",
    1034             :                                    sanitize_char(*p))));
    1035           0 :             p++;
    1036           0 :             break;
    1037          10 :         case 'p':
    1038             : 
    1039             :             /*
    1040             :              * The client requires channel binding.  Channel binding type
    1041             :              * follows, e.g., "p=tls-server-end-point".
    1042             :              */
    1043          10 :             if (!state->channel_binding_in_use)
    1044           0 :                 ereport(ERROR,
    1045             :                         (errcode(ERRCODE_PROTOCOL_VIOLATION),
    1046             :                          errmsg("malformed SCRAM message"),
    1047             :                          errdetail("The client selected SCRAM-SHA-256 without channel binding, but the SCRAM message includes channel binding data.")));
    1048             : 
    1049          10 :             channel_binding_type = read_attr_value(&p, 'p');
    1050             : 
    1051             :             /*
    1052             :              * The only channel binding type we support is
    1053             :              * tls-server-end-point.
    1054             :              */
    1055          10 :             if (strcmp(channel_binding_type, "tls-server-end-point") != 0)
    1056           0 :                 ereport(ERROR,
    1057             :                         (errcode(ERRCODE_PROTOCOL_VIOLATION),
    1058             :                          errmsg("unsupported SCRAM channel-binding type \"%s\"",
    1059             :                                 sanitize_str(channel_binding_type))));
    1060          10 :             break;
    1061           0 :         default:
    1062           0 :             ereport(ERROR,
    1063             :                     (errcode(ERRCODE_PROTOCOL_VIOLATION),
    1064             :                      errmsg("malformed SCRAM message"),
    1065             :                      errdetail("Unexpected channel-binding flag \"%s\".",
    1066             :                                sanitize_char(*p))));
    1067             :     }
    1068             : 
    1069             :     /*
    1070             :      * Forbid optional authzid (authorization identity).  We don't support it.
                     :      * With no authzid, the next character must be the comma that ends the
                     :      * gs2-header.
    1071             :      */
    1072          96 :     if (*p == 'a')
    1073           0 :         ereport(ERROR,
    1074             :                 (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
    1075             :                  errmsg("client uses authorization identity, but it is not supported")));
    1076          96 :     if (*p != ',')
    1077           0 :         ereport(ERROR,
    1078             :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
    1079             :                  errmsg("malformed SCRAM message"),
    1080             :                  errdetail("Unexpected attribute \"%s\" in client-first-message.",
    1081             :                            sanitize_char(*p))));
    1082          96 :     p++;
    1083             : 
    1084          96 :     state->client_first_message_bare = pstrdup(p);
    1085             : 
    1086             :     /*
    1087             :      * Any mandatory extensions would go here.  We don't support any.
    1088             :      *
    1089             :      * RFC 5802 specifies error code "e=extensions-not-supported" for this,
    1090             :      * but it can only be sent in the server-final message.  We prefer to fail
    1091             :      * immediately (which the RFC also allows).
    1092             :      */
    1093          96 :     if (*p == 'm')
    1094           0 :         ereport(ERROR,
    1095             :                 (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
    1096             :                  errmsg("client requires an unsupported SCRAM extension")));
    1097             : 
    1098             :     /*
    1099             :      * Read username.  Note: this is ignored.  We use the username from the
    1100             :      * startup message instead, still it is kept around if provided as it
    1101             :      * proves to be useful for debugging purposes.
    1102             :      */
    1103          96 :     state->client_username = read_attr_value(&p, 'n');
    1104             : 
    1105             :     /* read nonce and check that it is made of only printable characters */
    1106          96 :     state->client_nonce = read_attr_value(&p, 'r');
    1107          96 :     if (!is_scram_printable(state->client_nonce))
    1108           0 :         ereport(ERROR,
    1109             :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
    1110             :                  errmsg("non-printable characters in SCRAM nonce")));
    1111             : 
    1112             :     /*
    1113             :      * There can be any number of optional extensions after this.  We don't
    1114             :      * support any extensions, so ignore them.
                     :      * Each one is still read with read_any_attr(), which raises an error
                     :      * if it is not of the form "<letter>=<value>".
    1115             :      */
    1116          96 :     while (*p != '\0')
    1117           0 :         read_any_attr(&p, NULL);
    1118             : 
    1119             :     /* success! */
    1120          96 : }
    1121             : 
    1122             : /*
    1123             :  * Check that the nonce echoed in the client-final-message is exactly the
    1124             :  * client nonce followed by the server nonce.  Returns true if it matches.
    1125             :  */
    1126             : static bool
    1127          96 : verify_final_nonce(scram_state *state)
    1128             : {
    1129          96 :     int         client_nonce_len = strlen(state->client_nonce);
    1130          96 :     int         server_nonce_len = strlen(state->server_nonce);
    1131          96 :     int         final_nonce_len = strlen(state->client_final_nonce);
    1132             :     /* the lengths must add up before the two halves are compared */
    1133          96 :     if (final_nonce_len != client_nonce_len + server_nonce_len)
    1134           0 :         return false;
    1135          96 :     if (memcmp(state->client_final_nonce, state->client_nonce, client_nonce_len) != 0)
    1136           0 :         return false;
    1137          96 :     if (memcmp(state->client_final_nonce + client_nonce_len, state->server_nonce, server_nonce_len) != 0)
    1138           0 :         return false;
    1139             : 
    1140          96 :     return true;
    1141             : }
    1142             : 
    1143             : /*
    1144             :  * Verify the ClientProof taken from the client-final-message: recompute
    1145             :  * ClientSignature, recover ClientKey (saved in state) from the proof, then
    1146             :  * hash it and compare with StoredKey.  Returns true on a match, else false.
    1147             :  */
    1148             : static bool
    1149          96 : verify_client_proof(scram_state *state)
    1150             : {
    1151             :     uint8       ClientSignature[SCRAM_MAX_KEY_LEN];
    1152             :     uint8       client_StoredKey[SCRAM_MAX_KEY_LEN];
    1153          96 :     pg_hmac_ctx *ctx = pg_hmac_create(state->hash_type);
    1154             :     int         i;
    1155          96 :     const char *errstr = NULL;
    1156             : 
    1157             :     /*
    1158             :      * Calculate ClientSignature.  Note that we don't log directly a failure
    1159             :      * here even when processing the calculations as this could involve a mock
    1160             :      * authentication.
    1161             :      */
    1162         192 :     if (pg_hmac_init(ctx, state->StoredKey, state->key_length) < 0 ||
    1163          96 :         pg_hmac_update(ctx,
    1164          96 :                        (uint8 *) state->client_first_message_bare,
    1165         192 :                        strlen(state->client_first_message_bare)) < 0 ||
    1166         192 :         pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
    1167          96 :         pg_hmac_update(ctx,
    1168          96 :                        (uint8 *) state->server_first_message,
    1169         192 :                        strlen(state->server_first_message)) < 0 ||
    1170         192 :         pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
    1171          96 :         pg_hmac_update(ctx,
    1172          96 :                        (uint8 *) state->client_final_message_without_proof,
    1173         192 :                        strlen(state->client_final_message_without_proof)) < 0 ||
    1174          96 :         pg_hmac_final(ctx, ClientSignature, state->key_length) < 0)
    1175             :     {
    1176           0 :         elog(ERROR, "could not calculate client signature: %s",
    1177             :              pg_hmac_error(ctx));
    1178             :     }
    1179             : 
    1180          96 :     pg_hmac_free(ctx);
    1181             : 
    1182             :     /* ClientKey = ClientProof XOR ClientSignature, byte by byte */
    1183        3168 :     for (i = 0; i < state->key_length; i++)
    1184        3072 :         state->ClientKey[i] = state->ClientProof[i] ^ ClientSignature[i];
    1185             : 
    1186             :     /* Hash the recovered ClientKey, and compare the result with StoredKey */
    1187          96 :     if (scram_H(state->ClientKey, state->hash_type, state->key_length,
    1188             :                 client_StoredKey, &errstr) < 0)
    1189           0 :         elog(ERROR, "could not hash stored key: %s", errstr);
    1190             : 
    1191          96 :     if (memcmp(client_StoredKey, state->StoredKey, state->key_length) != 0)
    1192          10 :         return false;
    1193             : 
    1194          86 :     return true;
    1195             : }
    1196             : 
    1197             : /*
    1198             :  * Build the server-first-message: generate the server nonce, then format
    1199             :  * nonce, salt and iterations.  Both are kept in state; a copy is returned.
    1200             :  */
    1201             : static char *
    1202          96 : build_server_first_message(scram_state *state)
    1203             : {
    1204             :     /*------
    1205             :      * The syntax for the server-first-message is: (RFC 5802)
    1206             :      *
    1207             :      * server-first-message =
    1208             :      *                   [reserved-mext ","] nonce "," salt ","
    1209             :      *                   iteration-count ["," extensions]
    1210             :      *
    1211             :      * nonce           = "r=" c-nonce [s-nonce]
    1212             :      *                   ;; Second part provided by server.
    1213             :      *
    1214             :      * c-nonce         = printable
    1215             :      *
    1216             :      * s-nonce         = printable
    1217             :      *
    1218             :      * salt            = "s=" base64
    1219             :      *
    1220             :      * iteration-count = "i=" posit-number
    1221             :      *                   ;; A positive number.
    1222             :      *
    1223             :      * Example:
    1224             :      *
    1225             :      * r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096
    1226             :      *------
    1227             :      */
    1228             : 
    1229             :     /*
    1230             :      * Per the spec, the nonce may consist of any printable ASCII characters.
    1231             :      * For convenience, however, we don't use the whole range available,
    1232             :      * rather, we generate some random bytes, and base64 encode them.
    1233             :      */
    1234             :     char        raw_nonce[SCRAM_RAW_NONCE_LEN];
    1235             :     int         encoded_len;
    1236             : 
    1237          96 :     if (!pg_strong_random(raw_nonce, SCRAM_RAW_NONCE_LEN))
    1238           0 :         ereport(ERROR,
    1239             :                 (errcode(ERRCODE_INTERNAL_ERROR),
    1240             :                  errmsg("could not generate random nonce")));
    1241             : 
    1242          96 :     encoded_len = pg_b64_enc_len(SCRAM_RAW_NONCE_LEN);
    1243             :     /* don't forget the zero-terminator */
    1244          96 :     state->server_nonce = palloc(encoded_len + 1);
    1245          96 :     encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN,
    1246             :                                 state->server_nonce, encoded_len);
    1247          96 :     if (encoded_len < 0)
    1248           0 :         ereport(ERROR,
    1249             :                 (errcode(ERRCODE_INTERNAL_ERROR),
    1250             :                  errmsg("could not encode random nonce")));
    1251          96 :     state->server_nonce[encoded_len] = '\0';
    1252             : 
    1253          96 :     state->server_first_message =
    1254          96 :         psprintf("r=%s%s,s=%s,i=%d",
    1255             :                  state->client_nonce, state->server_nonce,
    1256             :                  state->salt, state->iterations);
    1257             : 
    1258          96 :     return pstrdup(state->server_first_message);
    1259             : }
    1260             : 
    1261             : 
    1262             : /*
    1263             :  * Read and parse the client-final-message, saving nonce and proof in state.
    1264             :  */
    1265             : static void
    1266          96 : read_client_final_message(scram_state *state, const char *input)
    1267             : {
    1268             :     char        attr;
    1269             :     char       *channel_binding;
    1270             :     char       *value;
    1271             :     char       *begin,
    1272             :                *proof;
    1273             :     char       *p;
    1274             :     char       *client_proof;
    1275             :     int         client_proof_len;
    1276             : 
    1277          96 :     begin = p = pstrdup(input);
    1278             : 
    1279             :     /*------
    1280             :      * The syntax for the client-final-message is: (RFC 5802)
    1281             :      *
    1282             :      * gs2-header      = gs2-cbind-flag "," [ authzid ] ","
    1283             :      *                   ;; GS2 header for SCRAM
    1284             :      *                   ;; (the actual GS2 header includes an optional
    1285             :      *                   ;; flag to indicate that the GSS mechanism is not
    1286             :      *                   ;; "standard", but since SCRAM is "standard", we
    1287             :      *                   ;; don't include that flag).
    1288             :      *
    1289             :      * cbind-input   = gs2-header [ cbind-data ]
    1290             :      *                   ;; cbind-data MUST be present for
    1291             :      *                   ;; gs2-cbind-flag of "p" and MUST be absent
    1292             :      *                   ;; for "y" or "n".
    1293             :      *
    1294             :      * channel-binding = "c=" base64
    1295             :      *                   ;; base64 encoding of cbind-input.
    1296             :      *
    1297             :      * proof           = "p=" base64
    1298             :      *
    1299             :      * client-final-message-without-proof =
    1300             :      *                   channel-binding "," nonce [","
    1301             :      *                   extensions]
    1302             :      *
    1303             :      * client-final-message =
    1304             :      *                   client-final-message-without-proof "," proof
    1305             :      *------
    1306             :      */
    1307             : 
    1308             :     /*
    1309             :      * Read channel binding.  This repeats the channel-binding flags and is
    1310             :      * then followed by the actual binding data depending on the type.
    1311             :      */
    1312          96 :     channel_binding = read_attr_value(&p, 'c');
    1313          96 :     if (state->channel_binding_in_use)
    1314             :     {
    1315             : #ifdef USE_SSL
    1316          10 :         const char *cbind_data = NULL;
    1317          10 :         size_t      cbind_data_len = 0;
    1318             :         size_t      cbind_header_len;
    1319             :         char       *cbind_input;
    1320             :         size_t      cbind_input_len;
    1321             :         char       *b64_message;
    1322             :         int         b64_message_len;
    1323             : 
    1324             :         Assert(state->cbind_flag == 'p');
    1325             : 
    1326             :         /* Fetch hash data of server's SSL certificate */
    1327          10 :         cbind_data = be_tls_get_certificate_hash(state->port,
    1328             :                                                  &cbind_data_len);
    1329             : 
    1330             :         /* should not happen */
    1331          10 :         if (cbind_data == NULL || cbind_data_len == 0)
    1332           0 :             elog(ERROR, "could not get server certificate hash");
    1333             : 
    1334          10 :         cbind_header_len = strlen("p=tls-server-end-point,,");    /* p=type,, */
    1335          10 :         cbind_input_len = cbind_header_len + cbind_data_len;
    1336          10 :         cbind_input = palloc(cbind_input_len);
    1337          10 :         snprintf(cbind_input, cbind_input_len, "p=tls-server-end-point,,");
    1338          10 :         memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len);
    1339             : 
    1340          10 :         b64_message_len = pg_b64_enc_len(cbind_input_len);
    1341             :         /* don't forget the zero-terminator */
    1342          10 :         b64_message = palloc(b64_message_len + 1);
    1343          10 :         b64_message_len = pg_b64_encode(cbind_input, cbind_input_len,
    1344             :                                         b64_message, b64_message_len);
    1345          10 :         if (b64_message_len < 0)
    1346           0 :             elog(ERROR, "could not encode channel binding data");
    1347          10 :         b64_message[b64_message_len] = '\0';
    1348             : 
    1349             :         /*
    1350             :          * Compare the value sent by the client with the value expected by the
    1351             :          * server.
    1352             :          */
    1353          10 :         if (strcmp(channel_binding, b64_message) != 0)
    1354           0 :             ereport(ERROR,
    1355             :                     (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
    1356             :                      errmsg("SCRAM channel binding check failed")));
    1357             : #else
    1358             :         /* shouldn't happen, because we checked this earlier already */
    1359             :         elog(ERROR, "channel binding not supported by this build");
    1360             : #endif
    1361             :     }
    1362             :     else
    1363             :     {
    1364             :         /*
    1365             :          * If we are not using channel binding, the binding data is expected
    1366             :          * to always be "biws", which is "n,," base64-encoded, or "eSws",
    1367             :          * which is "y,,".  We also have to check whether the flag is the same
    1368             :          * one that the client originally sent.
    1369             :          */
    1370          86 :         if (!(strcmp(channel_binding, "biws") == 0 && state->cbind_flag == 'n') &&
    1371           0 :             !(strcmp(channel_binding, "eSws") == 0 && state->cbind_flag == 'y'))
    1372           0 :             ereport(ERROR,
    1373             :                     (errcode(ERRCODE_PROTOCOL_VIOLATION),
    1374             :                      errmsg("unexpected SCRAM channel-binding attribute in client-final-message")));
    1375             :     }
    1376             : 
    1377          96 :     state->client_final_nonce = read_attr_value(&p, 'r');
    1378             : 
    1379             :     /* skip optional extensions until the "p" attribute; "proof" marks the end of the part before it */
    1380             :     do
    1381             :     {
    1382          96 :         proof = p - 1;
    1383          96 :         value = read_any_attr(&p, &attr);
    1384          96 :     } while (attr != 'p');
    1385             : 
    1386          96 :     client_proof_len = pg_b64_dec_len(strlen(value));
    1387          96 :     client_proof = palloc(client_proof_len);
    1388          96 :     if (pg_b64_decode(value, strlen(value), client_proof,
    1389          96 :                       client_proof_len) != state->key_length)
    1390           0 :         ereport(ERROR,
    1391             :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
    1392             :                  errmsg("malformed SCRAM message"),
    1393             :                  errdetail("Malformed proof in client-final-message.")));
    1394          96 :     memcpy(state->ClientProof, client_proof, state->key_length);
    1395          96 :     pfree(client_proof);
    1396             : 
    1397          96 :     if (*p != '\0')
    1398           0 :         ereport(ERROR,
    1399             :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
    1400             :                  errmsg("malformed SCRAM message"),
    1401             :                  errdetail("Garbage found at the end of client-final-message.")));
    1402             : 
    1403          96 :     state->client_final_message_without_proof = palloc(proof - begin + 1);
    1404          96 :     memcpy(state->client_final_message_without_proof, input, proof - begin);
    1405          96 :     state->client_final_message_without_proof[proof - begin] = '\0';
    1406          96 : }
    1407             : 
    1408             : /*
    1409             :  * Build the server-final-message: "v=" plus the base64 ServerSignature.
    1410             :  */
    1411             : static char *
    1412          86 : build_server_final_message(scram_state *state)
    1413             : {
    1414             :     uint8       ServerSignature[SCRAM_MAX_KEY_LEN];
    1415             :     char       *server_signature_base64;
    1416             :     int         siglen;
    1417          86 :     pg_hmac_ctx *ctx = pg_hmac_create(state->hash_type);
    1418             : 
    1419             :     /* calculate ServerSignature: HMAC keyed with ServerKey over the three messages */
    1420         172 :     if (pg_hmac_init(ctx, state->ServerKey, state->key_length) < 0 ||
    1421          86 :         pg_hmac_update(ctx,
    1422          86 :                        (uint8 *) state->client_first_message_bare,
    1423         172 :                        strlen(state->client_first_message_bare)) < 0 ||
    1424         172 :         pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
    1425          86 :         pg_hmac_update(ctx,
    1426          86 :                        (uint8 *) state->server_first_message,
    1427         172 :                        strlen(state->server_first_message)) < 0 ||
    1428         172 :         pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
    1429          86 :         pg_hmac_update(ctx,
    1430          86 :                        (uint8 *) state->client_final_message_without_proof,
    1431         172 :                        strlen(state->client_final_message_without_proof)) < 0 ||
    1432          86 :         pg_hmac_final(ctx, ServerSignature, state->key_length) < 0)
    1433             :     {
    1434           0 :         elog(ERROR, "could not calculate server signature: %s",
    1435             :              pg_hmac_error(ctx));
    1436             :     }
    1437             : 
    1438          86 :     pg_hmac_free(ctx);
    1439             : 
    1440          86 :     siglen = pg_b64_enc_len(state->key_length);
    1441             :     /* don't forget the zero-terminator */
    1442          86 :     server_signature_base64 = palloc(siglen + 1);
    1443          86 :     siglen = pg_b64_encode((const char *) ServerSignature,
    1444             :                            state->key_length, server_signature_base64,
    1445             :                            siglen);
    1446          86 :     if (siglen < 0)
    1447           0 :         elog(ERROR, "could not encode server signature");
    1448          86 :     server_signature_base64[siglen] = '\0';
    1449             : 
    1450             :     /*------
    1451             :      * The syntax for the server-final-message is: (RFC 5802)
    1452             :      *
    1453             :      * verifier        = "v=" base64
    1454             :      *                   ;; base-64 encoded ServerSignature.
    1455             :      *
    1456             :      * server-final-message = (server-error / verifier)
    1457             :      *                   ["," extensions]
    1458             :      *
    1459             :      *------
    1460             :      */
    1461          86 :     return psprintf("v=%s", server_signature_base64);
    1462             : }
    1463             : 
    1464             : 
    1465             : /*
    1466             :  * Deterministically generate salt for mock authentication, using a SHA256
    1467             :  * hash based on the username and a cluster-level secret key.  Returns a
    1468             :  * pointer to a static buffer (overwritten by each call), or NULL on failure.
    1469             :  */
    1470             : static char *
    1471           2 : scram_mock_salt(const char *username, pg_cryptohash_type hash_type,
    1472             :                 int key_length)
    1473             : {
    1474             :     pg_cryptohash_ctx *ctx;
    1475             :     static uint8 sha_digest[SCRAM_MAX_KEY_LEN];
    1476           2 :     char       *mock_auth_nonce = GetMockAuthenticationNonce();
    1477             : 
    1478             :     /*
    1479             :      * Generate salt using a SHA256 hash of the username and the cluster's
    1480             :      * mock authentication nonce.  (This works as long as the salt length is
    1481             :      * not larger than the SHA256 digest length.  If the salt is smaller, the
    1482             :      * caller will just ignore the extra data.)
    1483             :      */
    1484             :     StaticAssertDecl(PG_SHA256_DIGEST_LENGTH >= SCRAM_DEFAULT_SALT_LEN,
    1485             :                      "salt length greater than SHA256 digest length");
    1486             : 
    1487             :     /*
    1488             :      * This may be worth refreshing if support for more hash methods is
    1489             :      * added.
    1490             :      */
    1491             :     Assert(hash_type == PG_SHA256);
    1492             : 
    1493           2 :     ctx = pg_cryptohash_create(hash_type);
    1494           4 :     if (pg_cryptohash_init(ctx) < 0 ||
    1495           4 :         pg_cryptohash_update(ctx, (uint8 *) username, strlen(username)) < 0 ||
    1496           4 :         pg_cryptohash_update(ctx, (uint8 *) mock_auth_nonce, MOCK_AUTH_NONCE_LEN) < 0 ||
    1497           2 :         pg_cryptohash_final(ctx, sha_digest, key_length) < 0)
    1498             :     {
    1499           0 :         pg_cryptohash_free(ctx);
    1500           0 :         return NULL;
    1501             :     }
    1502           2 :     pg_cryptohash_free(ctx);
    1503             : 
    1504           2 :     return (char *) sha_digest;
    1505             : }

Generated by: LCOV version 1.14