LCOV - code coverage report
Current view: top level - src/backend/libpq - auth-scram.c (source / functions) Hit Total Coverage
Test: PostgreSQL 16beta1 Lines: 303 391 77.5 %
Date: 2023-05-30 23:12:14 Functions: 17 19 89.5 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : /*-------------------------------------------------------------------------
       2             :  *
       3             :  * auth-scram.c
       4             :  *    Server-side implementation of the SASL SCRAM-SHA-256 mechanism.
       5             :  *
       6             :  * See the following RFCs for more details:
       7             :  * - RFC 5802: https://tools.ietf.org/html/rfc5802
       8             :  * - RFC 5803: https://tools.ietf.org/html/rfc5803
       9             :  * - RFC 7677: https://tools.ietf.org/html/rfc7677
      10             :  *
      11             :  * Here are some differences:
      12             :  *
      13             :  * - Username from the authentication exchange is not used. The client
      14             :  *   should send an empty string as the username.
      15             :  *
      16             :  * - If the password isn't valid UTF-8, or contains characters prohibited
      17             :  *   by the SASLprep profile, we skip the SASLprep pre-processing and use
      18             :  *   the raw bytes in calculating the hash.
      19             :  *
      20             :  * - If channel binding is used, the channel binding type is always
      21             :  *   "tls-server-end-point".  The spec says the default is "tls-unique"
      22             :  *   (RFC 5802, section 6.1. Default Channel Binding), but there are some
      23             :  *   problems with that.  Firstly, not all SSL libraries provide an API to
      24             :  *   get the TLS Finished message, required to use "tls-unique".  Secondly,
      25             :  *   "tls-unique" is not specified for TLS v1.3, and as of this writing,
      26             :  *   it's not clear if there will be a replacement.  We could support both
      27             :  *   "tls-server-end-point" and "tls-unique", but for our use case,
      28             :  *   "tls-unique" doesn't really have any advantages.  The main advantage
      29             :  *   of "tls-unique" would be that it works even if the server doesn't
      30             :  *   have a certificate, but PostgreSQL requires a server certificate
      31             :  *   whenever SSL is used, anyway.
      32             :  *
      33             :  *
      34             :  * The password stored in pg_authid consists of the iteration count, salt,
      35             :  * StoredKey and ServerKey.
      36             :  *
      37             :  * SASLprep usage
      38             :  * --------------
      39             :  *
      40             :  * One notable difference to the SCRAM specification is that while the
      41             :  * specification dictates that the password is in UTF-8, and prohibits
      42             :  * certain characters, we are more lenient.  If the password isn't a valid
      43             :  * UTF-8 string, or contains prohibited characters, the raw bytes are used
      44             :  * to calculate the hash instead, without SASLprep processing.  This is
      45             :  * because PostgreSQL supports other encodings too, and the encoding being
      46             :  * used during authentication is undefined (client_encoding isn't set until
      47             :  * after authentication).  In effect, we try to interpret the password as
      48             :  * UTF-8 and apply SASLprep processing, but if it looks invalid, we assume
      49             :  * that it's in some other encoding.
      50             :  *
      51             :  * In the worst case, we misinterpret a password that's in a different
      52             :  * encoding as being Unicode, because it happens to consists entirely of
      53             :  * valid UTF-8 bytes, and we apply Unicode normalization to it.  As long
      54             :  * as we do that consistently, that will not lead to failed logins.
      55             :  * Fortunately, the UTF-8 byte sequences that are ignored by SASLprep
      56             :  * don't correspond to any commonly used characters in any of the other
      57             :  * supported encodings, so it should not lead to any significant loss in
      58             :  * entropy, even if the normalization is incorrectly applied to a
      59             :  * non-UTF-8 password.
      60             :  *
      61             :  * Error handling
      62             :  * --------------
      63             :  *
      64             :  * Don't reveal user information to an unauthenticated client.  We don't
      65             :  * want an attacker to be able to probe whether a particular username is
      66             :  * valid.  In SCRAM, the server has to read the salt and iteration count
      67             :  * from the user's stored secret, and send it to the client.  To avoid
      68             :  * revealing whether a user exists, when the client tries to authenticate
      69             :  * with a username that doesn't exist, or doesn't have a valid SCRAM
      70             :  * secret in pg_authid, we create a fake salt and iteration count
      71             :  * on-the-fly, and proceed with the authentication with that.  In the end,
      72             :  * we'll reject the attempt, as if an incorrect password was given.  When
      73             :  * we are performing a "mock" authentication, the 'doomed' flag in
      74             :  * scram_state is set.
      75             :  *
      76             :  * In the error messages, avoid printing strings from the client, unless
      77             :  * you check that they are pure ASCII.  We don't want an unauthenticated
      78             :  * attacker to be able to spam the logs with characters that are not valid
      79             :  * to the encoding being used, whatever that is.  We cannot avoid that in
      80             :  * general, after logging in, but let's do what we can here.
      81             :  *
      82             :  *
      83             :  * Portions Copyright (c) 1996-2023, PostgreSQL Global Development Group
      84             :  * Portions Copyright (c) 1994, Regents of the University of California
      85             :  *
      86             :  * src/backend/libpq/auth-scram.c
      87             :  *
      88             :  *-------------------------------------------------------------------------
      89             :  */
      90             : #include "postgres.h"
      91             : 
      92             : #include <unistd.h>
      93             : 
      94             : #include "access/xlog.h"
      95             : #include "catalog/pg_authid.h"
      96             : #include "catalog/pg_control.h"
      97             : #include "common/base64.h"
      98             : #include "common/hmac.h"
      99             : #include "common/saslprep.h"
     100             : #include "common/scram-common.h"
     101             : #include "common/sha2.h"
     102             : #include "libpq/auth.h"
     103             : #include "libpq/crypt.h"
     104             : #include "libpq/sasl.h"
     105             : #include "libpq/scram.h"
     106             : #include "miscadmin.h"
     107             : #include "utils/builtins.h"
     108             : #include "utils/timestamp.h"
     109             : 
     110             : static void scram_get_mechanisms(Port *port, StringInfo buf);
     111             : static void *scram_init(Port *port, const char *selected_mech,
     112             :                         const char *shadow_pass);
     113             : static int  scram_exchange(void *opaq, const char *input, int inputlen,
     114             :                            char **output, int *outputlen,
     115             :                            const char **logdetail);
     116             : 
     117             : /* Mechanism declaration */
     118             : const pg_be_sasl_mech pg_be_scram_mech = {
     119             :     scram_get_mechanisms,
     120             :     scram_init,
     121             :     scram_exchange
     122             : };
     123             : 
     124             : /*
     125             :  * Status data for a SCRAM authentication exchange.  This should be kept
     126             :  * internal to this file.
     127             :  */
     128             : typedef enum
     129             : {
     130             :     SCRAM_AUTH_INIT,
     131             :     SCRAM_AUTH_SALT_SENT,
     132             :     SCRAM_AUTH_FINISHED
     133             : } scram_state_enum;
     134             : 
     135             : typedef struct
     136             : {
     137             :     scram_state_enum state;
     138             : 
     139             :     const char *username;       /* username from startup packet */
     140             : 
     141             :     Port       *port;
     142             :     bool        channel_binding_in_use;
     143             : 
     144             :     /* State data depending on the hash type */
     145             :     pg_cryptohash_type hash_type;
     146             :     int         key_length;
     147             : 
     148             :     int         iterations;
     149             :     char       *salt;           /* base64-encoded */
     150             :     uint8       StoredKey[SCRAM_MAX_KEY_LEN];
     151             :     uint8       ServerKey[SCRAM_MAX_KEY_LEN];
     152             : 
     153             :     /* Fields of the first message from client */
     154             :     char        cbind_flag;
     155             :     char       *client_first_message_bare;
     156             :     char       *client_username;
     157             :     char       *client_nonce;
     158             : 
     159             :     /* Fields from the last message from client */
     160             :     char       *client_final_message_without_proof;
     161             :     char       *client_final_nonce;
     162             :     char        ClientProof[SCRAM_MAX_KEY_LEN];
     163             : 
     164             :     /* Fields generated in the server */
     165             :     char       *server_first_message;
     166             :     char       *server_nonce;
     167             : 
     168             :     /*
     169             :      * If something goes wrong during the authentication, or we are performing
     170             :      * a "mock" authentication (see comments at top of file), the 'doomed'
     171             :      * flag is set.  A reason for the failure, for the server log, is put in
     172             :      * 'logdetail'.
     173             :      */
     174             :     bool        doomed;
     175             :     char       *logdetail;
     176             : } scram_state;
     177             : 
     178             : static void read_client_first_message(scram_state *state, const char *input);
     179             : static void read_client_final_message(scram_state *state, const char *input);
     180             : static char *build_server_first_message(scram_state *state);
     181             : static char *build_server_final_message(scram_state *state);
     182             : static bool verify_client_proof(scram_state *state);
     183             : static bool verify_final_nonce(scram_state *state);
     184             : static void mock_scram_secret(const char *username, pg_cryptohash_type *hash_type,
     185             :                               int *iterations, int *key_length, char **salt,
     186             :                               uint8 *stored_key, uint8 *server_key);
     187             : static bool is_scram_printable(char *p);
     188             : static char *sanitize_char(char c);
     189             : static char *sanitize_str(const char *s);
     190             : static char *scram_mock_salt(const char *username,
     191             :                              pg_cryptohash_type hash_type,
     192             :                              int key_length);
     193             : 
     194             : /*
     195             :  * The number of iterations to use when generating new secrets.
     196             :  */
     197             : int         scram_sha_256_iterations = SCRAM_SHA_256_DEFAULT_ITERATIONS;
     198             : 
     199             : /*
     200             :  * Get a list of SASL mechanisms that this module supports.
     201             :  *
     202             :  * For the convenience of building the FE/BE packet that lists the
     203             :  * mechanisms, the names are appended to the given StringInfo buffer,
     204             :  * separated by '\0' bytes.
     205             :  */
     206             : static void
     207          76 : scram_get_mechanisms(Port *port, StringInfo buf)
     208             : {
     209             :     /*
     210             :      * Advertise the mechanisms in decreasing order of importance.  So the
     211             :      * channel-binding variants go first, if they are supported.  Channel
     212             :      * binding is only supported with SSL, and only if the SSL implementation
     213             :      * has a function to get the certificate's hash.
     214             :      */
     215             : #ifdef HAVE_BE_TLS_GET_CERTIFICATE_HASH
     216          76 :     if (port->ssl_in_use)
     217             :     {
     218          12 :         appendStringInfoString(buf, SCRAM_SHA_256_PLUS_NAME);
     219          12 :         appendStringInfoChar(buf, '\0');
     220             :     }
     221             : #endif
     222          76 :     appendStringInfoString(buf, SCRAM_SHA_256_NAME);
     223          76 :     appendStringInfoChar(buf, '\0');
     224          76 : }
     225             : 
     226             : /*
     227             :  * Initialize a new SCRAM authentication exchange status tracker.  This
     228             :  * needs to be called before doing any exchange.  It will be filled later
     229             :  * after the beginning of the exchange with authentication information.
     230             :  *
     231             :  * 'selected_mech' identifies the SASL mechanism that the client selected.
     232             :  * It should be one of the mechanisms that we support, as returned by
     233             :  * scram_get_mechanisms().
     234             :  *
     235             :  * 'shadow_pass' is the role's stored secret, from pg_authid.rolpassword.
     236             :  * The username was provided by the client in the startup message, and is
     237             :  * available in port->user_name.  If 'shadow_pass' is NULL, we still perform
     238             :  * an authentication exchange, but it will fail, as if an incorrect password
     239             :  * was given.
     240             :  */
     241             : static void *
     242          64 : scram_init(Port *port, const char *selected_mech, const char *shadow_pass)
     243             : {
     244             :     scram_state *state;
     245             :     bool        got_secret;
     246             : 
     247          64 :     state = (scram_state *) palloc0(sizeof(scram_state));
     248          64 :     state->port = port;
     249          64 :     state->state = SCRAM_AUTH_INIT;
     250             : 
     251             :     /*
     252             :      * Parse the selected mechanism.
     253             :      *
     254             :      * Note that if we don't support channel binding, either because the SSL
     255             :      * implementation doesn't support it or we're not using SSL at all, we
     256             :      * would not have advertised the PLUS variant in the first place.  If the
     257             :      * client nevertheless tries to select it, it's a protocol violation like
     258             :      * selecting any other SASL mechanism we don't support.
     259             :      */
     260             : #ifdef HAVE_BE_TLS_GET_CERTIFICATE_HASH
     261          64 :     if (strcmp(selected_mech, SCRAM_SHA_256_PLUS_NAME) == 0 && port->ssl_in_use)
     262           8 :         state->channel_binding_in_use = true;
     263             :     else
     264             : #endif
     265          56 :     if (strcmp(selected_mech, SCRAM_SHA_256_NAME) == 0)
     266          56 :         state->channel_binding_in_use = false;
     267             :     else
     268           0 :         ereport(ERROR,
     269             :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
     270             :                  errmsg("client selected an invalid SASL authentication mechanism")));
     271             : 
     272             :     /*
     273             :      * Parse the stored secret.
     274             :      */
     275          64 :     if (shadow_pass)
     276             :     {
     277          64 :         int         password_type = get_password_type(shadow_pass);
     278             : 
     279          64 :         if (password_type == PASSWORD_TYPE_SCRAM_SHA_256)
     280             :         {
     281          62 :             if (parse_scram_secret(shadow_pass, &state->iterations,
     282             :                                    &state->hash_type, &state->key_length,
     283             :                                    &state->salt,
     284          62 :                                    state->StoredKey,
     285          62 :                                    state->ServerKey))
     286          62 :                 got_secret = true;
     287             :             else
     288             :             {
     289             :                 /*
     290             :                  * The password looked like a SCRAM secret, but could not be
     291             :                  * parsed.
     292             :                  */
     293           0 :                 ereport(LOG,
     294             :                         (errmsg("invalid SCRAM secret for user \"%s\"",
     295             :                                 state->port->user_name)));
     296           0 :                 got_secret = false;
     297             :             }
     298             :         }
     299             :         else
     300             :         {
     301             :             /*
     302             :              * The user doesn't have SCRAM secret. (You cannot do SCRAM
     303             :              * authentication with an MD5 hash.)
     304             :              */
     305           4 :             state->logdetail = psprintf(_("User \"%s\" does not have a valid SCRAM secret."),
     306           2 :                                         state->port->user_name);
     307           2 :             got_secret = false;
     308             :         }
     309             :     }
     310             :     else
     311             :     {
     312             :         /*
     313             :          * The caller requested us to perform a dummy authentication.  This is
     314             :          * considered normal, since the caller requested it, so don't set log
     315             :          * detail.
     316             :          */
     317           0 :         got_secret = false;
     318             :     }
     319             : 
     320             :     /*
     321             :      * If the user did not have a valid SCRAM secret, we still go through the
     322             :      * motions with a mock one, and fail as if the client supplied an
     323             :      * incorrect password.  This is to avoid revealing information to an
     324             :      * attacker.
     325             :      */
     326          64 :     if (!got_secret)
     327             :     {
     328           2 :         mock_scram_secret(state->port->user_name, &state->hash_type,
     329             :                           &state->iterations, &state->key_length,
     330             :                           &state->salt,
     331           2 :                           state->StoredKey, state->ServerKey);
     332           2 :         state->doomed = true;
     333             :     }
     334             : 
     335          64 :     return state;
     336             : }
     337             : 
     338             : /*
     339             :  * Continue a SCRAM authentication exchange.
     340             :  *
     341             :  * 'input' is the SCRAM payload sent by the client.  On the first call,
     342             :  * 'input' contains the "Initial Client Response" that the client sent as
     343             :  * part of the SASLInitialResponse message, or NULL if no Initial Client
     344             :  * Response was given.  (The SASL specification distinguishes between an
     345             :  * empty response and non-existing one.)  On subsequent calls, 'input'
     346             :  * cannot be NULL.  For convenience in this function, the caller must
     347             :  * ensure that there is a null terminator at input[inputlen].
     348             :  *
     349             :  * The next message to send to client is saved in 'output', for a length
     350             :  * of 'outputlen'.  In the case of an error, optionally store a palloc'd
     351             :  * string at *logdetail that will be sent to the postmaster log (but not
     352             :  * the client).
     353             :  */
     354             : static int
     355         128 : scram_exchange(void *opaq, const char *input, int inputlen,
     356             :                char **output, int *outputlen, const char **logdetail)
     357             : {
     358         128 :     scram_state *state = (scram_state *) opaq;
     359             :     int         result;
     360             : 
     361         128 :     *output = NULL;
     362             : 
     363             :     /*
     364             :      * If the client didn't include an "Initial Client Response" in the
     365             :      * SASLInitialResponse message, send an empty challenge, to which the
     366             :      * client will respond with the same data that usually comes in the
     367             :      * Initial Client Response.
     368             :      */
     369         128 :     if (input == NULL)
     370             :     {
     371             :         Assert(state->state == SCRAM_AUTH_INIT);
     372             : 
     373           0 :         *output = pstrdup("");
     374           0 :         *outputlen = 0;
     375           0 :         return PG_SASL_EXCHANGE_CONTINUE;
     376             :     }
     377             : 
     378             :     /*
     379             :      * Check that the input length agrees with the string length of the input.
     380             :      * We can ignore inputlen after this.
     381             :      */
     382         128 :     if (inputlen == 0)
     383           0 :         ereport(ERROR,
     384             :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
     385             :                  errmsg("malformed SCRAM message"),
     386             :                  errdetail("The message is empty.")));
     387         128 :     if (inputlen != strlen(input))
     388           0 :         ereport(ERROR,
     389             :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
     390             :                  errmsg("malformed SCRAM message"),
     391             :                  errdetail("Message length does not match input length.")));
     392             : 
     393         128 :     switch (state->state)
     394             :     {
     395          64 :         case SCRAM_AUTH_INIT:
     396             : 
     397             :             /*
     398             :              * Initialization phase.  Receive the first message from client
     399             :              * and be sure that it parsed correctly.  Then send the challenge
     400             :              * to the client.
     401             :              */
     402          64 :             read_client_first_message(state, input);
     403             : 
     404             :             /* prepare message to send challenge */
     405          64 :             *output = build_server_first_message(state);
     406             : 
     407          64 :             state->state = SCRAM_AUTH_SALT_SENT;
     408          64 :             result = PG_SASL_EXCHANGE_CONTINUE;
     409          64 :             break;
     410             : 
     411          64 :         case SCRAM_AUTH_SALT_SENT:
     412             : 
     413             :             /*
     414             :              * Final phase for the server.  Receive the response to the
     415             :              * challenge previously sent, verify, and let the client know that
     416             :              * everything went well (or not).
     417             :              */
     418          64 :             read_client_final_message(state, input);
     419             : 
     420          64 :             if (!verify_final_nonce(state))
     421           0 :                 ereport(ERROR,
     422             :                         (errcode(ERRCODE_PROTOCOL_VIOLATION),
     423             :                          errmsg("invalid SCRAM response"),
     424             :                          errdetail("Nonce does not match.")));
     425             : 
     426             :             /*
     427             :              * Now check the final nonce and the client proof.
     428             :              *
     429             :              * If we performed a "mock" authentication that we knew would fail
     430             :              * from the get go, this is where we fail.
     431             :              *
     432             :              * The SCRAM specification includes an error code,
     433             :              * "invalid-proof", for authentication failure, but it also allows
     434             :              * erroring out in an application-specific way.  We choose to do
     435             :              * the latter, so that the error message for invalid password is
     436             :              * the same for all authentication methods.  The caller will call
     437             :              * ereport(), when we return PG_SASL_EXCHANGE_FAILURE with no
     438             :              * output.
     439             :              *
     440             :              * NB: the order of these checks is intentional.  We calculate the
     441             :              * client proof even in a mock authentication, even though it's
     442             :              * bound to fail, to thwart timing attacks to determine if a role
     443             :              * with the given name exists or not.
     444             :              */
     445          64 :             if (!verify_client_proof(state) || state->doomed)
     446             :             {
     447          10 :                 result = PG_SASL_EXCHANGE_FAILURE;
     448          10 :                 break;
     449             :             }
     450             : 
     451             :             /* Build final message for client */
     452          54 :             *output = build_server_final_message(state);
     453             : 
     454             :             /* Success! */
     455          54 :             result = PG_SASL_EXCHANGE_SUCCESS;
     456          54 :             state->state = SCRAM_AUTH_FINISHED;
     457          54 :             break;
     458             : 
     459           0 :         default:
     460           0 :             elog(ERROR, "invalid SCRAM exchange state");
     461             :             result = PG_SASL_EXCHANGE_FAILURE;
     462             :     }
     463             : 
     464         128 :     if (result == PG_SASL_EXCHANGE_FAILURE && state->logdetail && logdetail)
     465           2 :         *logdetail = state->logdetail;
     466             : 
     467         128 :     if (*output)
     468         118 :         *outputlen = strlen(*output);
     469             : 
     470         128 :     return result;
     471             : }
     472             : 
     473             : /*
     474             :  * Construct a SCRAM secret, for storing in pg_authid.rolpassword.
     475             :  *
     476             :  * The result is palloc'd, so caller is responsible for freeing it.
     477             :  */
     478             : char *
     479          90 : pg_be_scram_build_secret(const char *password)
     480             : {
     481             :     char       *prep_password;
     482             :     pg_saslprep_rc rc;
     483             :     char        saltbuf[SCRAM_DEFAULT_SALT_LEN];
     484             :     char       *result;
     485          90 :     const char *errstr = NULL;
     486             : 
     487             :     /*
     488             :      * Normalize the password with SASLprep.  If that doesn't work, because
     489             :      * the password isn't valid UTF-8 or contains prohibited characters, just
     490             :      * proceed with the original password.  (See comments at top of file.)
     491             :      */
     492          90 :     rc = pg_saslprep(password, &prep_password);
     493          90 :     if (rc == SASLPREP_SUCCESS)
     494          88 :         password = (const char *) prep_password;
     495             : 
     496             :     /* Generate random salt */
     497          90 :     if (!pg_strong_random(saltbuf, SCRAM_DEFAULT_SALT_LEN))
     498           0 :         ereport(ERROR,
     499             :                 (errcode(ERRCODE_INTERNAL_ERROR),
     500             :                  errmsg("could not generate random salt")));
     501             : 
     502          90 :     result = scram_build_secret(PG_SHA256, SCRAM_SHA_256_KEY_LEN,
     503             :                                 saltbuf, SCRAM_DEFAULT_SALT_LEN,
     504             :                                 scram_sha_256_iterations, password,
     505             :                                 &errstr);
     506             : 
     507          90 :     if (prep_password)
     508          88 :         pfree(prep_password);
     509             : 
     510          90 :     return result;
     511             : }
     512             : 
     513             : /*
     514             :  * Verify a plaintext password against a SCRAM secret.  This is used when
     515             :  * performing plaintext password authentication for a user that has a SCRAM
     516             :  * secret stored in pg_authid.
     517             :  */
     518             : bool
     519          30 : scram_verify_plain_password(const char *username, const char *password,
     520             :                             const char *secret)
     521             : {
     522             :     char       *encoded_salt;
     523             :     char       *salt;
     524             :     int         saltlen;
     525             :     int         iterations;
     526          30 :     int         key_length = 0;
     527             :     pg_cryptohash_type hash_type;
     528             :     uint8       salted_password[SCRAM_MAX_KEY_LEN];
     529             :     uint8       stored_key[SCRAM_MAX_KEY_LEN];
     530             :     uint8       server_key[SCRAM_MAX_KEY_LEN];
     531             :     uint8       computed_key[SCRAM_MAX_KEY_LEN];
     532             :     char       *prep_password;
     533             :     pg_saslprep_rc rc;
     534          30 :     const char *errstr = NULL;
     535             : 
     536          30 :     if (!parse_scram_secret(secret, &iterations, &hash_type, &key_length,
     537             :                             &encoded_salt, stored_key, server_key))
     538             :     {
     539             :         /*
     540             :          * The password looked like a SCRAM secret, but could not be parsed.
     541             :          */
     542           0 :         ereport(LOG,
     543             :                 (errmsg("invalid SCRAM secret for user \"%s\"", username)));
     544           0 :         return false;
     545             :     }
     546             : 
     547          30 :     saltlen = pg_b64_dec_len(strlen(encoded_salt));
     548          30 :     salt = palloc(saltlen);
     549          30 :     saltlen = pg_b64_decode(encoded_salt, strlen(encoded_salt), salt,
     550             :                             saltlen);
     551          30 :     if (saltlen < 0)
     552             :     {
     553           0 :         ereport(LOG,
     554             :                 (errmsg("invalid SCRAM secret for user \"%s\"", username)));
     555           0 :         return false;
     556             :     }
     557             : 
     558             :     /* Normalize the password */
     559          30 :     rc = pg_saslprep(password, &prep_password);
     560          30 :     if (rc == SASLPREP_SUCCESS)
     561          30 :         password = prep_password;
     562             : 
     563             :     /* Compute Server Key based on the user-supplied plaintext password */
     564          30 :     if (scram_SaltedPassword(password, hash_type, key_length,
     565             :                              salt, saltlen, iterations,
     566          30 :                              salted_password, &errstr) < 0 ||
     567          30 :         scram_ServerKey(salted_password, hash_type, key_length,
     568             :                         computed_key, &errstr) < 0)
     569             :     {
     570           0 :         elog(ERROR, "could not compute server key: %s", errstr);
     571             :     }
     572             : 
     573          30 :     if (prep_password)
     574          30 :         pfree(prep_password);
     575             : 
     576             :     /*
     577             :      * Compare the secret's Server Key with the one computed from the
     578             :      * user-supplied password.
     579             :      */
     580          30 :     return memcmp(computed_key, server_key, key_length) == 0;
     581             : }
     582             : 
     583             : 
     584             : /*
     585             :  * Parse and validate format of given SCRAM secret.
     586             :  *
     587             :  * On success, the iteration count, salt, stored key, and server key are
     588             :  * extracted from the secret, and returned to the caller.  For 'stored_key'
     589             :  * and 'server_key', the caller must pass pre-allocated buffers of size
     590             :  * SCRAM_MAX_KEY_LEN.  Salt is returned as a base64-encoded, null-terminated
     591             :  * string.  The buffer for the salt is palloc'd by this function.
     592             :  *
     593             :  * Returns true if the SCRAM secret has been parsed, and false otherwise.
     594             :  */
     595             : bool
     596         508 : parse_scram_secret(const char *secret, int *iterations,
     597             :                    pg_cryptohash_type *hash_type, int *key_length,
     598             :                    char **salt, uint8 *stored_key, uint8 *server_key)
     599             : {
     600             :     char       *v;
     601             :     char       *p;
     602             :     char       *scheme_str;
     603             :     char       *salt_str;
     604             :     char       *iterations_str;
     605             :     char       *storedkey_str;
     606             :     char       *serverkey_str;
     607             :     int         decoded_len;
     608             :     char       *decoded_salt_buf;
     609             :     char       *decoded_stored_buf;
     610             :     char       *decoded_server_buf;
     611             : 
     612             :     /*
     613             :      * The secret is of form:
     614             :      *
     615             :      * SCRAM-SHA-256$<iterations>:<salt>$<storedkey>:<serverkey>
     616             :      */
     617         508 :     v = pstrdup(secret);
     618         508 :     if ((scheme_str = strtok(v, "$")) == NULL)
     619           0 :         goto invalid_secret;
     620         508 :     if ((iterations_str = strtok(NULL, ":")) == NULL)
     621         200 :         goto invalid_secret;
     622         308 :     if ((salt_str = strtok(NULL, "$")) == NULL)
     623          12 :         goto invalid_secret;
     624         296 :     if ((storedkey_str = strtok(NULL, ":")) == NULL)
     625           0 :         goto invalid_secret;
     626         296 :     if ((serverkey_str = strtok(NULL, "")) == NULL)
     627           0 :         goto invalid_secret;
     628             : 
     629             :     /* Parse the fields */
     630         296 :     if (strcmp(scheme_str, "SCRAM-SHA-256") != 0)
     631           0 :         goto invalid_secret;
     632         296 :     *hash_type = PG_SHA256;
     633         296 :     *key_length = SCRAM_SHA_256_KEY_LEN;
     634             : 
     635         296 :     errno = 0;
     636         296 :     *iterations = strtol(iterations_str, &p, 10);
     637         296 :     if (*p || errno != 0)
     638           0 :         goto invalid_secret;
     639             : 
     640             :     /*
     641             :      * Verify that the salt is in Base64-encoded format, by decoding it,
     642             :      * although we return the encoded version to the caller.
     643             :      */
     644         296 :     decoded_len = pg_b64_dec_len(strlen(salt_str));
     645         296 :     decoded_salt_buf = palloc(decoded_len);
     646         296 :     decoded_len = pg_b64_decode(salt_str, strlen(salt_str),
     647             :                                 decoded_salt_buf, decoded_len);
     648         296 :     if (decoded_len < 0)
     649           0 :         goto invalid_secret;
     650         296 :     *salt = pstrdup(salt_str);
     651             : 
     652             :     /*
     653             :      * Decode StoredKey and ServerKey.
     654             :      */
     655         296 :     decoded_len = pg_b64_dec_len(strlen(storedkey_str));
     656         296 :     decoded_stored_buf = palloc(decoded_len);
     657         296 :     decoded_len = pg_b64_decode(storedkey_str, strlen(storedkey_str),
     658             :                                 decoded_stored_buf, decoded_len);
     659         296 :     if (decoded_len != *key_length)
     660          12 :         goto invalid_secret;
     661         284 :     memcpy(stored_key, decoded_stored_buf, *key_length);
     662             : 
     663         284 :     decoded_len = pg_b64_dec_len(strlen(serverkey_str));
     664         284 :     decoded_server_buf = palloc(decoded_len);
     665         284 :     decoded_len = pg_b64_decode(serverkey_str, strlen(serverkey_str),
     666             :                                 decoded_server_buf, decoded_len);
     667         284 :     if (decoded_len != *key_length)
     668          12 :         goto invalid_secret;
     669         272 :     memcpy(server_key, decoded_server_buf, *key_length);
     670             : 
     671         272 :     return true;
     672             : 
     673         236 : invalid_secret:
     674         236 :     *salt = NULL;
     675         236 :     return false;
     676             : }
     677             : 
     678             : /*
     679             :  * Generate plausible SCRAM secret parameters for mock authentication.
     680             :  *
     681             :  * In a normal authentication, these are extracted from the secret
     682             :  * stored in the server.  This function generates values that look
     683             :  * realistic, for when there is no stored secret, using SCRAM-SHA-256.
     684             :  *
     685             :  * Like in parse_scram_secret(), for 'stored_key' and 'server_key', the
     686             :  * caller must pass pre-allocated buffers of size SCRAM_MAX_KEY_LEN, and
     687             :  * the buffer for the salt is palloc'd by this function.
     688             :  */
     689             : static void
     690           2 : mock_scram_secret(const char *username, pg_cryptohash_type *hash_type,
     691             :                   int *iterations, int *key_length, char **salt,
     692             :                   uint8 *stored_key, uint8 *server_key)
     693             : {
     694             :     char       *raw_salt;
     695             :     char       *encoded_salt;
     696             :     int         encoded_len;
     697             : 
     698             :     /* Enforce the use of SHA-256, which would be realistic enough */
     699           2 :     *hash_type = PG_SHA256;
     700           2 :     *key_length = SCRAM_SHA_256_KEY_LEN;
     701             : 
     702             :     /*
     703             :      * Generate deterministic salt.
     704             :      *
     705             :      * Note that we cannot reveal any information to an attacker here so the
     706             :      * error messages need to remain generic.  This should never fail anyway
     707             :      * as the salt generated for mock authentication uses the cluster's nonce
     708             :      * value.
     709             :      */
     710           2 :     raw_salt = scram_mock_salt(username, *hash_type, *key_length);
     711           2 :     if (raw_salt == NULL)
     712           0 :         elog(ERROR, "could not encode salt");
     713             : 
     714           2 :     encoded_len = pg_b64_enc_len(SCRAM_DEFAULT_SALT_LEN);
     715             :     /* don't forget the zero-terminator */
     716           2 :     encoded_salt = (char *) palloc(encoded_len + 1);
     717           2 :     encoded_len = pg_b64_encode(raw_salt, SCRAM_DEFAULT_SALT_LEN, encoded_salt,
     718             :                                 encoded_len);
     719             : 
     720           2 :     if (encoded_len < 0)
     721           0 :         elog(ERROR, "could not encode salt");
     722           2 :     encoded_salt[encoded_len] = '\0';
     723             : 
     724           2 :     *salt = encoded_salt;
     725           2 :     *iterations = SCRAM_SHA_256_DEFAULT_ITERATIONS;
     726             : 
     727             :     /* StoredKey and ServerKey are not used in a doomed authentication */
     728           2 :     memset(stored_key, 0, SCRAM_MAX_KEY_LEN);
     729           2 :     memset(server_key, 0, SCRAM_MAX_KEY_LEN);
     730           2 : }
     731             : 
     732             : /*
     733             :  * Read the value in a given SCRAM exchange message for given attribute.
     734             :  */
     735             : static char *
     736         264 : read_attr_value(char **input, char attr)
     737             : {
     738         264 :     char       *begin = *input;
     739             :     char       *end;
     740             : 
     741         264 :     if (*begin != attr)
     742           0 :         ereport(ERROR,
     743             :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
     744             :                  errmsg("malformed SCRAM message"),
     745             :                  errdetail("Expected attribute \"%c\" but found \"%s\".",
     746             :                            attr, sanitize_char(*begin))));
     747         264 :     begin++;
     748             : 
     749         264 :     if (*begin != '=')
     750           0 :         ereport(ERROR,
     751             :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
     752             :                  errmsg("malformed SCRAM message"),
     753             :                  errdetail("Expected character \"=\" for attribute \"%c\".", attr)));
     754         264 :     begin++;
     755             : 
     756         264 :     end = begin;
     757        5864 :     while (*end && *end != ',')
     758        5600 :         end++;
     759             : 
     760         264 :     if (*end)
     761             :     {
     762         200 :         *end = '\0';
     763         200 :         *input = end + 1;
     764             :     }
     765             :     else
     766          64 :         *input = end;
     767             : 
     768         264 :     return begin;
     769             : }
     770             : 
     771             : static bool
     772          64 : is_scram_printable(char *p)
     773             : {
     774             :     /*------
     775             :      * Printable characters, as defined by SCRAM spec: (RFC 5802)
     776             :      *
     777             :      *  printable       = %x21-2B / %x2D-7E
     778             :      *                    ;; Printable ASCII except ",".
     779             :      *                    ;; Note that any "printable" is also
     780             :      *                    ;; a valid "value".
     781             :      *------
     782             :      */
     783        1600 :     for (; *p; p++)
     784             :     {
     785        1536 :         if (*p < 0x21 || *p > 0x7E || *p == 0x2C /* comma */ )
     786           0 :             return false;
     787             :     }
     788          64 :     return true;
     789             : }
     790             : 
     791             : /*
     792             :  * Convert an arbitrary byte to printable form.  For error messages.
     793             :  *
     794             :  * If it's a printable ASCII character, print it as a single character.
     795             :  * otherwise, print it in hex.
     796             :  *
     797             :  * The returned pointer points to a static buffer.
     798             :  */
     799             : static char *
     800           0 : sanitize_char(char c)
     801             : {
     802             :     static char buf[5];
     803             : 
     804           0 :     if (c >= 0x21 && c <= 0x7E)
     805           0 :         snprintf(buf, sizeof(buf), "'%c'", c);
     806             :     else
     807           0 :         snprintf(buf, sizeof(buf), "0x%02x", (unsigned char) c);
     808           0 :     return buf;
     809             : }
     810             : 
     811             : /*
     812             :  * Convert an arbitrary string to printable form, for error messages.
     813             :  *
     814             :  * Anything that's not a printable ASCII character is replaced with
     815             :  * '?', and the string is truncated at 30 characters.
     816             :  *
     817             :  * The returned pointer points to a static buffer.
     818             :  */
     819             : static char *
     820           0 : sanitize_str(const char *s)
     821             : {
     822             :     static char buf[30 + 1];
     823             :     int         i;
     824             : 
     825           0 :     for (i = 0; i < sizeof(buf) - 1; i++)
     826             :     {
     827           0 :         char        c = s[i];
     828             : 
     829           0 :         if (c == '\0')
     830           0 :             break;
     831             : 
     832           0 :         if (c >= 0x21 && c <= 0x7E)
     833           0 :             buf[i] = c;
     834             :         else
     835           0 :             buf[i] = '?';
     836             :     }
     837           0 :     buf[i] = '\0';
     838           0 :     return buf;
     839             : }
     840             : 
     841             : /*
     842             :  * Read the next attribute and value in a SCRAM exchange message.
     843             :  *
     844             :  * The attribute character is set in *attr_p, the attribute value is the
     845             :  * return value.
     846             :  */
     847             : static char *
     848          64 : read_any_attr(char **input, char *attr_p)
     849             : {
     850          64 :     char       *begin = *input;
     851             :     char       *end;
     852          64 :     char        attr = *begin;
     853             : 
     854          64 :     if (attr == '\0')
     855           0 :         ereport(ERROR,
     856             :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
     857             :                  errmsg("malformed SCRAM message"),
     858             :                  errdetail("Attribute expected, but found end of string.")));
     859             : 
     860             :     /*------
     861             :      * attr-val        = ALPHA "=" value
     862             :      *                   ;; Generic syntax of any attribute sent
     863             :      *                   ;; by server or client
     864             :      *------
     865             :      */
     866          64 :     if (!((attr >= 'A' && attr <= 'Z') ||
     867          64 :           (attr >= 'a' && attr <= 'z')))
     868           0 :         ereport(ERROR,
     869             :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
     870             :                  errmsg("malformed SCRAM message"),
     871             :                  errdetail("Attribute expected, but found invalid character \"%s\".",
     872             :                            sanitize_char(attr))));
     873          64 :     if (attr_p)
     874          64 :         *attr_p = attr;
     875          64 :     begin++;
     876             : 
     877          64 :     if (*begin != '=')
     878           0 :         ereport(ERROR,
     879             :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
     880             :                  errmsg("malformed SCRAM message"),
     881             :                  errdetail("Expected character \"=\" for attribute \"%c\".", attr)));
     882          64 :     begin++;
     883             : 
     884          64 :     end = begin;
     885        2880 :     while (*end && *end != ',')
     886        2816 :         end++;
     887             : 
     888          64 :     if (*end)
     889             :     {
     890           0 :         *end = '\0';
     891           0 :         *input = end + 1;
     892             :     }
     893             :     else
     894          64 :         *input = end;
     895             : 
     896          64 :     return begin;
     897             : }
     898             : 
/*
 * Read and parse the first message from client in the context of a SCRAM
 * authentication exchange message.
 *
 * On return, 'state' holds:
 *  - the gs2 channel-binding flag (cbind_flag);
 *  - a copy of the client-first-message-bare (client_first_message_bare);
 *  - the username and nonce sent by the client (client_username,
 *    client_nonce).
 *
 * The username and nonce point into a palloc'd copy of 'input' made by this
 * function, so 'input' itself is never modified.
 *
 * At this stage, any errors will be reported directly with ereport(ERROR).
 */
static void
read_client_first_message(scram_state *state, const char *input)
{
    /* Work on a private copy: the attribute readers write NULs into it. */
    char       *p = pstrdup(input);
    char       *channel_binding_type;


    /*------
     * The syntax for the client-first-message is: (RFC 5802)
     *
     * saslname        = 1*(value-safe-char / "=2C" / "=3D")
     *                   ;; Conforms to <value>.
     *
     * authzid         = "a=" saslname
     *                   ;; Protocol specific.
     *
     * cb-name         = 1*(ALPHA / DIGIT / "." / "-")
     *                    ;; See RFC 5056, Section 7.
     *                    ;; E.g., "tls-server-end-point" or
     *                    ;; "tls-unique".
     *
     * gs2-cbind-flag  = ("p=" cb-name) / "n" / "y"
     *                   ;; "n" -> client doesn't support channel binding.
     *                   ;; "y" -> client does support channel binding
     *                   ;;        but thinks the server does not.
     *                   ;; "p" -> client requires channel binding.
     *                   ;; The selected channel binding follows "p=".
     *
     * gs2-header      = gs2-cbind-flag "," [ authzid ] ","
     *                   ;; GS2 header for SCRAM
     *                   ;; (the actual GS2 header includes an optional
     *                   ;; flag to indicate that the GSS mechanism is not
     *                   ;; "standard", but since SCRAM is "standard", we
     *                   ;; don't include that flag).
     *
     * username        = "n=" saslname
     *                   ;; Usernames are prepared using SASLprep.
     *
     * reserved-mext  = "m=" 1*(value-char)
     *                   ;; Reserved for signaling mandatory extensions.
     *                   ;; The exact syntax will be defined in
     *                   ;; the future.
     *
     * nonce           = "r=" c-nonce [s-nonce]
     *                   ;; Second part provided by server.
     *
     * c-nonce         = printable
     *
     * client-first-message-bare =
     *                   [reserved-mext ","]
     *                   username "," nonce ["," extensions]
     *
     * client-first-message =
     *                   gs2-header client-first-message-bare
     *
     * For example:
     * n,,n=user,r=fyko+d2lbbFgONRv9qkxdawL
     *
     * The "n,," in the beginning means that the client doesn't support
     * channel binding, and no authzid is given.  "n=user" is the username.
     * However, in PostgreSQL the username is sent in the startup packet, and
     * the username in the SCRAM exchange is ignored.  libpq always sends it
     * as an empty string.  The last part, "r=fyko+d2lbbFgONRv9qkxdawL" is
     * the client nonce.
     *------
     */

    /*
     * Read gs2-cbind-flag.  (For details see also RFC 5802 Section 6 "Channel
     * Binding".)
     *
     * Every branch below that returns normally leaves p positioned just
     * after the flag's terminating comma.  That is where the optional
     * authzid (or else the second comma of the gs2-header) begins.
     */
    state->cbind_flag = *p;
    switch (*p)
    {
        case 'n':

            /*
             * The client does not support channel binding or has simply
             * decided to not use it.  In that case just let it go.
             */
            if (state->channel_binding_in_use)
                ereport(ERROR,
                        (errcode(ERRCODE_PROTOCOL_VIOLATION),
                         errmsg("malformed SCRAM message"),
                         errdetail("The client selected SCRAM-SHA-256-PLUS, but the SCRAM message does not include channel binding data.")));

            /* Skip the flag, then require and skip the comma after it. */
            p++;
            if (*p != ',')
                ereport(ERROR,
                        (errcode(ERRCODE_PROTOCOL_VIOLATION),
                         errmsg("malformed SCRAM message"),
                         errdetail("Comma expected, but found character \"%s\".",
                                   sanitize_char(*p))));
            p++;
            break;
        case 'y':

            /*
             * The client supports channel binding and thinks that the server
             * does not.  In this case, the server must fail authentication if
             * it supports channel binding.
             */
            if (state->channel_binding_in_use)
                ereport(ERROR,
                        (errcode(ERRCODE_PROTOCOL_VIOLATION),
                         errmsg("malformed SCRAM message"),
                         errdetail("The client selected SCRAM-SHA-256-PLUS, but the SCRAM message does not include channel binding data.")));

#ifdef HAVE_BE_TLS_GET_CERTIFICATE_HASH
            /*
             * Only builds that can compute the certificate hash support
             * channel binding.  Over SSL such a server supports it, so a
             * "y" flag here means the client was misled, and the exchange
             * must be rejected.
             */
            if (state->port->ssl_in_use)
                ereport(ERROR,
                        (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
                         errmsg("SCRAM channel binding negotiation error"),
                         errdetail("The client supports SCRAM channel binding but thinks the server does not.  "
                                   "However, this server does support channel binding.")));
#endif
            /* Skip the flag, then require and skip the comma after it. */
            p++;
            if (*p != ',')
                ereport(ERROR,
                        (errcode(ERRCODE_PROTOCOL_VIOLATION),
                         errmsg("malformed SCRAM message"),
                         errdetail("Comma expected, but found character \"%s\".",
                                   sanitize_char(*p))));
            p++;
            break;
        case 'p':

            /*
             * The client requires channel binding.  Channel binding type
             * follows, e.g., "p=tls-server-end-point".
             */
            if (!state->channel_binding_in_use)
                ereport(ERROR,
                        (errcode(ERRCODE_PROTOCOL_VIOLATION),
                         errmsg("malformed SCRAM message"),
                         errdetail("The client selected SCRAM-SHA-256 without channel binding, but the SCRAM message includes channel binding data.")));

            /*
             * read_attr_value() consumes "p=<cb-name>" together with its
             * trailing comma, if present.  That leaves p at the same spot as
             * in the 'n' and 'y' branches.
             */
            channel_binding_type = read_attr_value(&p, 'p');

            /*
             * The only channel binding type we support is
             * tls-server-end-point.
             */
            if (strcmp(channel_binding_type, "tls-server-end-point") != 0)
                ereport(ERROR,
                        (errcode(ERRCODE_PROTOCOL_VIOLATION),
                         errmsg("unsupported SCRAM channel-binding type \"%s\"",
                                sanitize_str(channel_binding_type))));
            break;
        default:
            ereport(ERROR,
                    (errcode(ERRCODE_PROTOCOL_VIOLATION),
                     errmsg("malformed SCRAM message"),
                     errdetail("Unexpected channel-binding flag \"%s\".",
                               sanitize_char(*p))));
    }

    /*
     * Forbid optional authzid (authorization identity).  We don't support it.
     */
    if (*p == 'a')
        ereport(ERROR,
                (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
                 errmsg("client uses authorization identity, but it is not supported")));
    /* Without an authzid, the gs2-header must end with this second comma. */
    if (*p != ',')
        ereport(ERROR,
                (errcode(ERRCODE_PROTOCOL_VIOLATION),
                 errmsg("malformed SCRAM message"),
                 errdetail("Unexpected attribute \"%s\" in client-first-message.",
                           sanitize_char(*p))));
    p++;

    /*
     * Everything from here on is the client-first-message-bare.  Copy it
     * now, before the read_attr_value() calls below overwrite its separating
     * commas with NULs.  (The copy is presumably consumed later when building
     * the SCRAM AuthMessage -- NOTE(review): confirm against the caller.)
     */
    state->client_first_message_bare = pstrdup(p);

    /*
     * Any mandatory extensions would go here.  We don't support any.
     *
     * RFC 5802 specifies error code "e=extensions-not-supported" for this,
     * but it can only be sent in the server-final message.  We prefer to fail
     * immediately (which the RFC also allows).
     */
    if (*p == 'm')
        ereport(ERROR,
                (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
                 errmsg("client requires an unsupported SCRAM extension")));

    /*
     * Read username.  Note: this is ignored.  We use the username from the
     * startup message instead, still it is kept around if provided as it
     * proves to be useful for debugging purposes.
     */
    state->client_username = read_attr_value(&p, 'n');

    /* read nonce and check that it is made of only printable characters */
    state->client_nonce = read_attr_value(&p, 'r');
    if (!is_scram_printable(state->client_nonce))
        ereport(ERROR,
                (errcode(ERRCODE_PROTOCOL_VIOLATION),
                 errmsg("non-printable characters in SCRAM nonce")));

    /*
     * There can be any number of optional extensions after this.  We don't
     * support any extensions, so ignore them.
     */
    while (*p != '\0')
        read_any_attr(&p, NULL);

    /* success! */
}
    1114             : 
    1115             : /*
    1116             :  * Verify the final nonce contained in the last message received from
    1117             :  * client in an exchange.
    1118             :  */
    1119             : static bool
    1120          64 : verify_final_nonce(scram_state *state)
    1121             : {
    1122          64 :     int         client_nonce_len = strlen(state->client_nonce);
    1123          64 :     int         server_nonce_len = strlen(state->server_nonce);
    1124          64 :     int         final_nonce_len = strlen(state->client_final_nonce);
    1125             : 
    1126          64 :     if (final_nonce_len != client_nonce_len + server_nonce_len)
    1127           0 :         return false;
    1128          64 :     if (memcmp(state->client_final_nonce, state->client_nonce, client_nonce_len) != 0)
    1129           0 :         return false;
    1130          64 :     if (memcmp(state->client_final_nonce + client_nonce_len, state->server_nonce, server_nonce_len) != 0)
    1131           0 :         return false;
    1132             : 
    1133          64 :     return true;
    1134             : }
    1135             : 
    1136             : /*
    1137             :  * Verify the client proof contained in the last message received from
    1138             :  * client in an exchange.  Returns true if the verification is a success,
    1139             :  * or false for a failure.
    1140             :  */
    1141             : static bool
    1142          64 : verify_client_proof(scram_state *state)
    1143             : {
    1144             :     uint8       ClientSignature[SCRAM_MAX_KEY_LEN];
    1145             :     uint8       ClientKey[SCRAM_MAX_KEY_LEN];
    1146             :     uint8       client_StoredKey[SCRAM_MAX_KEY_LEN];
    1147          64 :     pg_hmac_ctx *ctx = pg_hmac_create(state->hash_type);
    1148             :     int         i;
    1149          64 :     const char *errstr = NULL;
    1150             : 
    1151             :     /*
    1152             :      * Calculate ClientSignature.  Note that we don't log directly a failure
    1153             :      * here even when processing the calculations as this could involve a mock
    1154             :      * authentication.
    1155             :      */
    1156         128 :     if (pg_hmac_init(ctx, state->StoredKey, state->key_length) < 0 ||
    1157          64 :         pg_hmac_update(ctx,
    1158          64 :                        (uint8 *) state->client_first_message_bare,
    1159         128 :                        strlen(state->client_first_message_bare)) < 0 ||
    1160         128 :         pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
    1161          64 :         pg_hmac_update(ctx,
    1162          64 :                        (uint8 *) state->server_first_message,
    1163         128 :                        strlen(state->server_first_message)) < 0 ||
    1164         128 :         pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
    1165          64 :         pg_hmac_update(ctx,
    1166          64 :                        (uint8 *) state->client_final_message_without_proof,
    1167         128 :                        strlen(state->client_final_message_without_proof)) < 0 ||
    1168          64 :         pg_hmac_final(ctx, ClientSignature, state->key_length) < 0)
    1169             :     {
    1170           0 :         elog(ERROR, "could not calculate client signature: %s",
    1171             :              pg_hmac_error(ctx));
    1172             :     }
    1173             : 
    1174          64 :     pg_hmac_free(ctx);
    1175             : 
    1176             :     /* Extract the ClientKey that the client calculated from the proof */
    1177        2112 :     for (i = 0; i < state->key_length; i++)
    1178        2048 :         ClientKey[i] = state->ClientProof[i] ^ ClientSignature[i];
    1179             : 
    1180             :     /* Hash it one more time, and compare with StoredKey */
    1181          64 :     if (scram_H(ClientKey, state->hash_type, state->key_length,
    1182             :                 client_StoredKey, &errstr) < 0)
    1183           0 :         elog(ERROR, "could not hash stored key: %s", errstr);
    1184             : 
    1185          64 :     if (memcmp(client_StoredKey, state->StoredKey, state->key_length) != 0)
    1186          10 :         return false;
    1187             : 
    1188          54 :     return true;
    1189             : }
    1190             : 
    1191             : /*
    1192             :  * Build the first server-side message sent to the client in a SCRAM
    1193             :  * communication exchange.
    1194             :  */
    1195             : static char *
    1196          64 : build_server_first_message(scram_state *state)
    1197             : {
    1198             :     /*------
    1199             :      * The syntax for the server-first-message is: (RFC 5802)
    1200             :      *
    1201             :      * server-first-message =
    1202             :      *                   [reserved-mext ","] nonce "," salt ","
    1203             :      *                   iteration-count ["," extensions]
    1204             :      *
    1205             :      * nonce           = "r=" c-nonce [s-nonce]
    1206             :      *                   ;; Second part provided by server.
    1207             :      *
    1208             :      * c-nonce         = printable
    1209             :      *
    1210             :      * s-nonce         = printable
    1211             :      *
    1212             :      * salt            = "s=" base64
    1213             :      *
    1214             :      * iteration-count = "i=" posit-number
    1215             :      *                   ;; A positive number.
    1216             :      *
    1217             :      * Example:
    1218             :      *
    1219             :      * r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096
    1220             :      *------
    1221             :      */
    1222             : 
    1223             :     /*
    1224             :      * Per the spec, the nonce may consist of any printable ASCII characters.
    1225             :      * For convenience, however, we don't use the whole range available,
    1226             :      * rather, we generate some random bytes, and base64 encode them.
    1227             :      */
    1228             :     char        raw_nonce[SCRAM_RAW_NONCE_LEN];
    1229             :     int         encoded_len;
    1230             : 
    1231          64 :     if (!pg_strong_random(raw_nonce, SCRAM_RAW_NONCE_LEN))
    1232           0 :         ereport(ERROR,
    1233             :                 (errcode(ERRCODE_INTERNAL_ERROR),
    1234             :                  errmsg("could not generate random nonce")));
    1235             : 
    1236          64 :     encoded_len = pg_b64_enc_len(SCRAM_RAW_NONCE_LEN);
    1237             :     /* don't forget the zero-terminator */
    1238          64 :     state->server_nonce = palloc(encoded_len + 1);
    1239          64 :     encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN,
    1240             :                                 state->server_nonce, encoded_len);
    1241          64 :     if (encoded_len < 0)
    1242           0 :         ereport(ERROR,
    1243             :                 (errcode(ERRCODE_INTERNAL_ERROR),
    1244             :                  errmsg("could not encode random nonce")));
    1245          64 :     state->server_nonce[encoded_len] = '\0';
    1246             : 
    1247          64 :     state->server_first_message =
    1248          64 :         psprintf("r=%s%s,s=%s,i=%d",
    1249             :                  state->client_nonce, state->server_nonce,
    1250             :                  state->salt, state->iterations);
    1251             : 
    1252          64 :     return pstrdup(state->server_first_message);
    1253             : }
    1254             : 
    1255             : 
    1256             : /*
    1257             :  * Read and parse the final message received from client.
    1258             :  */
    1259             : static void
    1260          64 : read_client_final_message(scram_state *state, const char *input)
    1261             : {
    1262             :     char        attr;
    1263             :     char       *channel_binding;
    1264             :     char       *value;
    1265             :     char       *begin,
    1266             :                *proof;
    1267             :     char       *p;
    1268             :     char       *client_proof;
    1269             :     int         client_proof_len;
    1270             : 
    1271          64 :     begin = p = pstrdup(input);
    1272             : 
    1273             :     /*------
    1274             :      * The syntax for the server-first-message is: (RFC 5802)
    1275             :      *
    1276             :      * gs2-header      = gs2-cbind-flag "," [ authzid ] ","
    1277             :      *                   ;; GS2 header for SCRAM
    1278             :      *                   ;; (the actual GS2 header includes an optional
    1279             :      *                   ;; flag to indicate that the GSS mechanism is not
    1280             :      *                   ;; "standard", but since SCRAM is "standard", we
    1281             :      *                   ;; don't include that flag).
    1282             :      *
    1283             :      * cbind-input   = gs2-header [ cbind-data ]
    1284             :      *                   ;; cbind-data MUST be present for
    1285             :      *                   ;; gs2-cbind-flag of "p" and MUST be absent
    1286             :      *                   ;; for "y" or "n".
    1287             :      *
    1288             :      * channel-binding = "c=" base64
    1289             :      *                   ;; base64 encoding of cbind-input.
    1290             :      *
    1291             :      * proof           = "p=" base64
    1292             :      *
    1293             :      * client-final-message-without-proof =
    1294             :      *                   channel-binding "," nonce [","
    1295             :      *                   extensions]
    1296             :      *
    1297             :      * client-final-message =
    1298             :      *                   client-final-message-without-proof "," proof
    1299             :      *------
    1300             :      */
    1301             : 
    1302             :     /*
    1303             :      * Read channel binding.  This repeats the channel-binding flags and is
    1304             :      * then followed by the actual binding data depending on the type.
    1305             :      */
    1306          64 :     channel_binding = read_attr_value(&p, 'c');
    1307          64 :     if (state->channel_binding_in_use)
    1308             :     {
    1309             : #ifdef HAVE_BE_TLS_GET_CERTIFICATE_HASH
    1310           8 :         const char *cbind_data = NULL;
    1311           8 :         size_t      cbind_data_len = 0;
    1312             :         size_t      cbind_header_len;
    1313             :         char       *cbind_input;
    1314             :         size_t      cbind_input_len;
    1315             :         char       *b64_message;
    1316             :         int         b64_message_len;
    1317             : 
    1318             :         Assert(state->cbind_flag == 'p');
    1319             : 
    1320             :         /* Fetch hash data of server's SSL certificate */
    1321           8 :         cbind_data = be_tls_get_certificate_hash(state->port,
    1322             :                                                  &cbind_data_len);
    1323             : 
    1324             :         /* should not happen */
    1325           8 :         if (cbind_data == NULL || cbind_data_len == 0)
    1326           0 :             elog(ERROR, "could not get server certificate hash");
    1327             : 
    1328           8 :         cbind_header_len = strlen("p=tls-server-end-point,,");    /* p=type,, */
    1329           8 :         cbind_input_len = cbind_header_len + cbind_data_len;
    1330           8 :         cbind_input = palloc(cbind_input_len);
    1331           8 :         snprintf(cbind_input, cbind_input_len, "p=tls-server-end-point,,");
    1332           8 :         memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len);
    1333             : 
    1334           8 :         b64_message_len = pg_b64_enc_len(cbind_input_len);
    1335             :         /* don't forget the zero-terminator */
    1336           8 :         b64_message = palloc(b64_message_len + 1);
    1337           8 :         b64_message_len = pg_b64_encode(cbind_input, cbind_input_len,
    1338             :                                         b64_message, b64_message_len);
    1339           8 :         if (b64_message_len < 0)
    1340           0 :             elog(ERROR, "could not encode channel binding data");
    1341           8 :         b64_message[b64_message_len] = '\0';
    1342             : 
    1343             :         /*
    1344             :          * Compare the value sent by the client with the value expected by the
    1345             :          * server.
    1346             :          */
    1347           8 :         if (strcmp(channel_binding, b64_message) != 0)
    1348           0 :             ereport(ERROR,
    1349             :                     (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
    1350             :                      errmsg("SCRAM channel binding check failed")));
    1351             : #else
    1352             :         /* shouldn't happen, because we checked this earlier already */
    1353             :         elog(ERROR, "channel binding not supported by this build");
    1354             : #endif
    1355             :     }
    1356             :     else
    1357             :     {
    1358             :         /*
    1359             :          * If we are not using channel binding, the binding data is expected
    1360             :          * to always be "biws", which is "n,," base64-encoded, or "eSws",
    1361             :          * which is "y,,".  We also have to check whether the flag is the same
    1362             :          * one that the client originally sent.
    1363             :          */
    1364          56 :         if (!(strcmp(channel_binding, "biws") == 0 && state->cbind_flag == 'n') &&
    1365           0 :             !(strcmp(channel_binding, "eSws") == 0 && state->cbind_flag == 'y'))
    1366           0 :             ereport(ERROR,
    1367             :                     (errcode(ERRCODE_PROTOCOL_VIOLATION),
    1368             :                      errmsg("unexpected SCRAM channel-binding attribute in client-final-message")));
    1369             :     }
    1370             : 
    1371          64 :     state->client_final_nonce = read_attr_value(&p, 'r');
    1372             : 
    1373             :     /* ignore optional extensions, read until we find "p" attribute */
    1374             :     do
    1375             :     {
    1376          64 :         proof = p - 1;
    1377          64 :         value = read_any_attr(&p, &attr);
    1378          64 :     } while (attr != 'p');
    1379             : 
    1380          64 :     client_proof_len = pg_b64_dec_len(strlen(value));
    1381          64 :     client_proof = palloc(client_proof_len);
    1382          64 :     if (pg_b64_decode(value, strlen(value), client_proof,
    1383          64 :                       client_proof_len) != state->key_length)
    1384           0 :         ereport(ERROR,
    1385             :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
    1386             :                  errmsg("malformed SCRAM message"),
    1387             :                  errdetail("Malformed proof in client-final-message.")));
    1388          64 :     memcpy(state->ClientProof, client_proof, state->key_length);
    1389          64 :     pfree(client_proof);
    1390             : 
    1391          64 :     if (*p != '\0')
    1392           0 :         ereport(ERROR,
    1393             :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
    1394             :                  errmsg("malformed SCRAM message"),
    1395             :                  errdetail("Garbage found at the end of client-final-message.")));
    1396             : 
    1397          64 :     state->client_final_message_without_proof = palloc(proof - begin + 1);
    1398          64 :     memcpy(state->client_final_message_without_proof, input, proof - begin);
    1399          64 :     state->client_final_message_without_proof[proof - begin] = '\0';
    1400          64 : }
    1401             : 
    1402             : /*
    1403             :  * Build the final server-side message of an exchange.
    1404             :  */
    1405             : static char *
    1406          54 : build_server_final_message(scram_state *state)
    1407             : {
    1408             :     uint8       ServerSignature[SCRAM_MAX_KEY_LEN];
    1409             :     char       *server_signature_base64;
    1410             :     int         siglen;
    1411          54 :     pg_hmac_ctx *ctx = pg_hmac_create(state->hash_type);
    1412             : 
    1413             :     /* calculate ServerSignature */
    1414         108 :     if (pg_hmac_init(ctx, state->ServerKey, state->key_length) < 0 ||
    1415          54 :         pg_hmac_update(ctx,
    1416          54 :                        (uint8 *) state->client_first_message_bare,
    1417         108 :                        strlen(state->client_first_message_bare)) < 0 ||
    1418         108 :         pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
    1419          54 :         pg_hmac_update(ctx,
    1420          54 :                        (uint8 *) state->server_first_message,
    1421         108 :                        strlen(state->server_first_message)) < 0 ||
    1422         108 :         pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
    1423          54 :         pg_hmac_update(ctx,
    1424          54 :                        (uint8 *) state->client_final_message_without_proof,
    1425         108 :                        strlen(state->client_final_message_without_proof)) < 0 ||
    1426          54 :         pg_hmac_final(ctx, ServerSignature, state->key_length) < 0)
    1427             :     {
    1428           0 :         elog(ERROR, "could not calculate server signature: %s",
    1429             :              pg_hmac_error(ctx));
    1430             :     }
    1431             : 
    1432          54 :     pg_hmac_free(ctx);
    1433             : 
    1434          54 :     siglen = pg_b64_enc_len(state->key_length);
    1435             :     /* don't forget the zero-terminator */
    1436          54 :     server_signature_base64 = palloc(siglen + 1);
    1437          54 :     siglen = pg_b64_encode((const char *) ServerSignature,
    1438             :                            state->key_length, server_signature_base64,
    1439             :                            siglen);
    1440          54 :     if (siglen < 0)
    1441           0 :         elog(ERROR, "could not encode server signature");
    1442          54 :     server_signature_base64[siglen] = '\0';
    1443             : 
    1444             :     /*------
    1445             :      * The syntax for the server-final-message is: (RFC 5802)
    1446             :      *
    1447             :      * verifier        = "v=" base64
    1448             :      *                   ;; base-64 encoded ServerSignature.
    1449             :      *
    1450             :      * server-final-message = (server-error / verifier)
    1451             :      *                   ["," extensions]
    1452             :      *
    1453             :      *------
    1454             :      */
    1455          54 :     return psprintf("v=%s", server_signature_base64);
    1456             : }
    1457             : 
    1458             : 
    1459             : /*
    1460             :  * Deterministically generate salt for mock authentication, using a SHA256
    1461             :  * hash based on the username and a cluster-level secret key.  Returns a
    1462             :  * pointer to a static buffer of size SCRAM_DEFAULT_SALT_LEN, or NULL.
    1463             :  */
    1464             : static char *
    1465           2 : scram_mock_salt(const char *username, pg_cryptohash_type hash_type,
    1466             :                 int key_length)
    1467             : {
    1468             :     pg_cryptohash_ctx *ctx;
    1469             :     static uint8 sha_digest[SCRAM_MAX_KEY_LEN];
    1470           2 :     char       *mock_auth_nonce = GetMockAuthenticationNonce();
    1471             : 
    1472             :     /*
    1473             :      * Generate salt using a SHA256 hash of the username and the cluster's
    1474             :      * mock authentication nonce.  (This works as long as the salt length is
    1475             :      * not larger than the SHA256 digest length.  If the salt is smaller, the
    1476             :      * caller will just ignore the extra data.)
    1477             :      */
    1478             :     StaticAssertDecl(PG_SHA256_DIGEST_LENGTH >= SCRAM_DEFAULT_SALT_LEN,
    1479             :                      "salt length greater than SHA256 digest length");
    1480             : 
    1481             :     /*
    1482             :      * This may be worth refreshing if support for more hash methods is\
    1483             :      * added.
    1484             :      */
    1485             :     Assert(hash_type == PG_SHA256);
    1486             : 
    1487           2 :     ctx = pg_cryptohash_create(hash_type);
    1488           4 :     if (pg_cryptohash_init(ctx) < 0 ||
    1489           4 :         pg_cryptohash_update(ctx, (uint8 *) username, strlen(username)) < 0 ||
    1490           4 :         pg_cryptohash_update(ctx, (uint8 *) mock_auth_nonce, MOCK_AUTH_NONCE_LEN) < 0 ||
    1491           2 :         pg_cryptohash_final(ctx, sha_digest, key_length) < 0)
    1492             :     {
    1493           0 :         pg_cryptohash_free(ctx);
    1494           0 :         return NULL;
    1495             :     }
    1496           2 :     pg_cryptohash_free(ctx);
    1497             : 
    1498           2 :     return (char *) sha_digest;
    1499             : }

Generated by: LCOV version 1.14