LCOV - code coverage report
Current view: top level - src/backend/libpq - auth-scram.c (source / functions)
Test: PostgreSQL 19devel
Test Date: 2026-02-17 17:20:33

            Coverage   Total   Hit
Lines:        78.4 %     398   312
Functions:    89.5 %      19    17

            Line data    Source code
       1              : /*-------------------------------------------------------------------------
       2              :  *
       3              :  * auth-scram.c
       4              :  *    Server-side implementation of the SASL SCRAM-SHA-256 mechanism.
       5              :  *
       6              :  * See the following RFCs for more details:
       7              :  * - RFC 5802: https://tools.ietf.org/html/rfc5802
       8              :  * - RFC 5803: https://tools.ietf.org/html/rfc5803
       9              :  * - RFC 7677: https://tools.ietf.org/html/rfc7677
      10              :  *
      11              :  * Here are some differences:
      12              :  *
      13              :  * - Username from the authentication exchange is not used. The client
      14              :  *   should send an empty string as the username.
      15              :  *
      16              :  * - If the password isn't valid UTF-8, or contains characters prohibited
      17              :  *   by the SASLprep profile, we skip the SASLprep pre-processing and use
      18              :  *   the raw bytes in calculating the hash.
      19              :  *
      20              :  * - If channel binding is used, the channel binding type is always
      21              :  *   "tls-server-end-point".  The spec says the default is "tls-unique"
      22              :  *   (RFC 5802, section 6.1. Default Channel Binding), but there are some
      23              :  *   problems with that.  Firstly, not all SSL libraries provide an API to
      24              :  *   get the TLS Finished message, required to use "tls-unique".  Secondly,
      25              :  *   "tls-unique" is not specified for TLS v1.3, and as of this writing,
      26              :  *   it's not clear if there will be a replacement.  We could support both
      27              :  *   "tls-server-end-point" and "tls-unique", but for our use case,
      28              :  *   "tls-unique" doesn't really have any advantages.  The main advantage
      29              :  *   of "tls-unique" would be that it works even if the server doesn't
      30              :  *   have a certificate, but PostgreSQL requires a server certificate
      31              :  *   whenever SSL is used, anyway.
      32              :  *
      33              :  *
      34              :  * The password stored in pg_authid consists of the iteration count, salt,
      35              :  * StoredKey and ServerKey.
      36              :  *
      37              :  * SASLprep usage
      38              :  * --------------
      39              :  *
      40              :  * One notable difference to the SCRAM specification is that while the
      41              :  * specification dictates that the password is in UTF-8, and prohibits
      42              :  * certain characters, we are more lenient.  If the password isn't a valid
      43              :  * UTF-8 string, or contains prohibited characters, the raw bytes are used
      44              :  * to calculate the hash instead, without SASLprep processing.  This is
      45              :  * because PostgreSQL supports other encodings too, and the encoding being
      46              :  * used during authentication is undefined (client_encoding isn't set until
      47              :  * after authentication).  In effect, we try to interpret the password as
      48              :  * UTF-8 and apply SASLprep processing, but if it looks invalid, we assume
      49              :  * that it's in some other encoding.
      50              :  *
      51              :  * In the worst case, we misinterpret a password that's in a different
      52              :  * encoding as being Unicode, because it happens to consist entirely of
      53              :  * valid UTF-8 bytes, and we apply Unicode normalization to it.  As long
      54              :  * as we do that consistently, that will not lead to failed logins.
      55              :  * Fortunately, the UTF-8 byte sequences that are ignored by SASLprep
      56              :  * don't correspond to any commonly used characters in any of the other
      57              :  * supported encodings, so it should not lead to any significant loss in
      58              :  * entropy, even if the normalization is incorrectly applied to a
      59              :  * non-UTF-8 password.
      60              :  *
      61              :  * Error handling
      62              :  * --------------
      63              :  *
      64              :  * Don't reveal user information to an unauthenticated client.  We don't
      65              :  * want an attacker to be able to probe whether a particular username is
      66              :  * valid.  In SCRAM, the server has to read the salt and iteration count
      67              :  * from the user's stored secret, and send it to the client.  To avoid
      68              :  * revealing whether a user exists, when the client tries to authenticate
      69              :  * with a username that doesn't exist, or doesn't have a valid SCRAM
      70              :  * secret in pg_authid, we create a fake salt and iteration count
      71              :  * on-the-fly, and proceed with the authentication with that.  In the end,
      72              :  * we'll reject the attempt, as if an incorrect password was given.  When
      73              :  * we are performing a "mock" authentication, the 'doomed' flag in
      74              :  * scram_state is set.
      75              :  *
      76              :  * In the error messages, avoid printing strings from the client, unless
      77              :  * you check that they are pure ASCII.  We don't want an unauthenticated
      78              :  * attacker to be able to spam the logs with characters that are not valid
      79              :  * to the encoding being used, whatever that is.  We cannot avoid that in
      80              :  * general, after logging in, but let's do what we can here.
      81              :  *
      82              :  *
      83              :  * Portions Copyright (c) 1996-2026, PostgreSQL Global Development Group
      84              :  * Portions Copyright (c) 1994, Regents of the University of California
      85              :  *
      86              :  * src/backend/libpq/auth-scram.c
      87              :  *
      88              :  *-------------------------------------------------------------------------
      89              :  */
      90              : #include "postgres.h"
      91              : 
      92              : #include <unistd.h>
      93              : 
      94              : #include "access/xlog.h"
      95              : #include "catalog/pg_control.h"
      96              : #include "common/base64.h"
      97              : #include "common/hmac.h"
      98              : #include "common/saslprep.h"
      99              : #include "common/scram-common.h"
     100              : #include "common/sha2.h"
     101              : #include "libpq/crypt.h"
     102              : #include "libpq/sasl.h"
     103              : #include "libpq/scram.h"
     104              : #include "miscadmin.h"
     105              : 
     106              : static void scram_get_mechanisms(Port *port, StringInfo buf);
     107              : static void *scram_init(Port *port, const char *selected_mech,
     108              :                         const char *shadow_pass);
     109              : static int  scram_exchange(void *opaq, const char *input, int inputlen,
     110              :                            char **output, int *outputlen,
     111              :                            const char **logdetail);
     112              : 
     113              : /* Mechanism declaration */
     114              : const pg_be_sasl_mech pg_be_scram_mech = {
     115              :     scram_get_mechanisms,
     116              :     scram_init,
     117              :     scram_exchange,
     118              : 
     119              :     PG_MAX_SASL_MESSAGE_LENGTH
     120              : };
     121              : 
     122              : /*
     123              :  * Status data for a SCRAM authentication exchange.  This should be kept
     124              :  * internal to this file.
     125              :  */
     126              : typedef enum
     127              : {
     128              :     SCRAM_AUTH_INIT,
     129              :     SCRAM_AUTH_SALT_SENT,
     130              :     SCRAM_AUTH_FINISHED,
     131              : } scram_state_enum;
     132              : 
     133              : typedef struct
     134              : {
     135              :     scram_state_enum state;
     136              : 
     137              :     Port       *port;
     138              :     bool        channel_binding_in_use;
     139              : 
     140              :     /* State data depending on the hash type */
     141              :     pg_cryptohash_type hash_type;
     142              :     int         key_length;
     143              : 
     144              :     int         iterations;
     145              :     char       *salt;           /* base64-encoded */
     146              :     uint8       ClientKey[SCRAM_MAX_KEY_LEN];
     147              :     uint8       StoredKey[SCRAM_MAX_KEY_LEN];
     148              :     uint8       ServerKey[SCRAM_MAX_KEY_LEN];
     149              : 
     150              :     /* Fields of the first message from client */
     151              :     char        cbind_flag;
     152              :     char       *client_first_message_bare;
     153              :     char       *client_username;
     154              :     char       *client_nonce;
     155              : 
     156              :     /* Fields from the last message from client */
     157              :     char       *client_final_message_without_proof;
     158              :     char       *client_final_nonce;
     159              :     uint8       ClientProof[SCRAM_MAX_KEY_LEN];
     160              : 
     161              :     /* Fields generated in the server */
     162              :     char       *server_first_message;
     163              :     char       *server_nonce;
     164              : 
     165              :     /*
     166              :      * If something goes wrong during the authentication, or we are performing
     167              :      * a "mock" authentication (see comments at top of file), the 'doomed'
     168              :      * flag is set.  A reason for the failure, for the server log, is put in
     169              :      * 'logdetail'.
     170              :      */
     171              :     bool        doomed;
     172              :     char       *logdetail;
     173              : } scram_state;
     174              : 
     175              : static void read_client_first_message(scram_state *state, const char *input);
     176              : static void read_client_final_message(scram_state *state, const char *input);
     177              : static char *build_server_first_message(scram_state *state);
     178              : static char *build_server_final_message(scram_state *state);
     179              : static bool verify_client_proof(scram_state *state);
     180              : static bool verify_final_nonce(scram_state *state);
     181              : static void mock_scram_secret(const char *username, pg_cryptohash_type *hash_type,
     182              :                               int *iterations, int *key_length, char **salt,
     183              :                               uint8 *stored_key, uint8 *server_key);
     184              : static bool is_scram_printable(char *p);
     185              : static char *sanitize_char(char c);
     186              : static char *sanitize_str(const char *s);
     187              : static uint8 *scram_mock_salt(const char *username,
     188              :                               pg_cryptohash_type hash_type,
     189              :                               int key_length);
     190              : 
     191              : /*
     192              :  * The number of iterations to use when generating new secrets.
     193              :  */
     194              : int         scram_sha_256_iterations = SCRAM_SHA_256_DEFAULT_ITERATIONS;
     195              : 
     196              : /*
     197              :  * Get a list of SASL mechanisms that this module supports.
     198              :  *
     199              :  * For the convenience of building the FE/BE packet that lists the
     200              :  * mechanisms, the names are appended to the given StringInfo buffer,
     201              :  * separated by '\0' bytes.
     202              :  */
     203              : static void
     204           70 : scram_get_mechanisms(Port *port, StringInfo buf)
     205              : {
     206              :     /*
     207              :      * Advertise the mechanisms in decreasing order of importance.  So the
     208              :      * channel-binding variants go first, if they are supported.  Channel
     209              :      * binding is only supported with SSL.
     210              :      */
     211              : #ifdef USE_SSL
     212           70 :     if (port->ssl_in_use)
     213              :     {
     214            6 :         appendStringInfoString(buf, SCRAM_SHA_256_PLUS_NAME);
     215            6 :         appendStringInfoChar(buf, '\0');
     216              :     }
     217              : #endif
     218           70 :     appendStringInfoString(buf, SCRAM_SHA_256_NAME);
     219           70 :     appendStringInfoChar(buf, '\0');
     220           70 : }
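
/*
 * Illustration only, not part of auth-scram.c: a minimal standalone sketch
 * of the '\0'-separated mechanism list that scram_get_mechanisms() builds
 * above.  It uses a fixed char buffer instead of StringInfo.  The helpers
 * append_mech() and build_mech_list() are hypothetical names, and the two
 * string literals are assumed to be the values behind
 * SCRAM_SHA_256_PLUS_NAME and SCRAM_SHA_256_NAME.
 */

```c
#include <assert.h>
#include <stddef.h>
#include <string.h>

/*
 * Append 'name' and its terminating '\0' to buf at offset 'used'.
 * Returns the new used length, or 0 if the name does not fit.
 */
static size_t
append_mech(char *buf, size_t cap, size_t used, const char *name)
{
    size_t      len = strlen(name) + 1; /* keep the '\0' as the separator */

    assert(used <= cap);
    if (len > cap - used)
        return 0;
    memcpy(buf + used, name, len);
    return used + len;
}

/*
 * Build the list in decreasing order of preference: the channel-binding
 * variant first, and only when SSL is in use.  Returns the number of
 * bytes written, or 0 if the buffer is too small.
 */
static size_t
build_mech_list(char *buf, size_t cap, int ssl_in_use)
{
    size_t      used = 0;

    if (ssl_in_use)
    {
        used = append_mech(buf, cap, used, "SCRAM-SHA-256-PLUS");
        if (used == 0)
            return 0;
    }
    return append_mech(buf, cap, used, "SCRAM-SHA-256");
}
```

/*
 * With ssl_in_use set, a 64-byte buffer ends up holding
 * "SCRAM-SHA-256-PLUS\0SCRAM-SHA-256\0" (33 bytes).  A reader can walk the
 * list by advancing strlen(p) + 1 bytes per name.
 */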
     221              : 
     222              : /*
     223              :  * Initialize a new SCRAM authentication exchange status tracker.  This
     224              :  * needs to be called before doing any exchange.  It will be filled later
     225              :  * after the beginning of the exchange with authentication information.
     226              :  *
     227              :  * 'selected_mech' identifies the SASL mechanism that the client selected.
     228              :  * It should be one of the mechanisms that we support, as returned by
     229              :  * scram_get_mechanisms().
     230              :  *
     231              :  * 'shadow_pass' is the role's stored secret, from pg_authid.rolpassword.
     232              :  * The username was provided by the client in the startup message, and is
     233              :  * available in port->user_name.  If 'shadow_pass' is NULL, we still perform
     234              :  * an authentication exchange, but it will fail, as if an incorrect password
     235              :  * was given.
     236              :  */
     237              : static void *
     238           59 : scram_init(Port *port, const char *selected_mech, const char *shadow_pass)
     239              : {
     240              :     scram_state *state;
     241              :     bool        got_secret;
     242              : 
     243           59 :     state = palloc0_object(scram_state);
     244           59 :     state->port = port;
     245           59 :     state->state = SCRAM_AUTH_INIT;
     246              : 
     247              :     /*
     248              :      * Parse the selected mechanism.
     249              :      *
     250              :      * Note that if we don't support channel binding, or if we're not using
     251              :      * SSL at all, we would not have advertised the PLUS variant in the first
     252              :      * place.  If the client nevertheless tries to select it, it's a protocol
     253              :      * violation like selecting any other SASL mechanism we don't support.
     254              :      */
     255              : #ifdef USE_SSL
     256           59 :     if (strcmp(selected_mech, SCRAM_SHA_256_PLUS_NAME) == 0 && port->ssl_in_use)
     257            4 :         state->channel_binding_in_use = true;
     258              :     else
     259              : #endif
     260           55 :     if (strcmp(selected_mech, SCRAM_SHA_256_NAME) == 0)
     261           55 :         state->channel_binding_in_use = false;
     262              :     else
     263            0 :         ereport(ERROR,
     264              :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
     265              :                  errmsg("client selected an invalid SASL authentication mechanism")));
     266              : 
     267              :     /*
     268              :      * Parse the stored secret.
     269              :      */
     270           59 :     if (shadow_pass)
     271              :     {
     272           58 :         int         password_type = get_password_type(shadow_pass);
     273              : 
     274           58 :         if (password_type == PASSWORD_TYPE_SCRAM_SHA_256)
     275              :         {
     276           57 :             if (parse_scram_secret(shadow_pass, &state->iterations,
     277              :                                    &state->hash_type, &state->key_length,
     278              :                                    &state->salt,
     279           57 :                                    state->StoredKey,
     280           57 :                                    state->ServerKey))
     281           57 :                 got_secret = true;
     282              :             else
     283              :             {
     284              :                 /*
     285              :                  * The password looked like a SCRAM secret, but could not be
     286              :                  * parsed.
     287              :                  */
     288            0 :                 ereport(LOG,
     289              :                         (errmsg("invalid SCRAM secret for user \"%s\"",
     290              :                                 state->port->user_name)));
     291            0 :                 got_secret = false;
     292              :             }
     293              :         }
     294              :         else
     295              :         {
     296              :             /*
     297              :              * The user doesn't have a SCRAM secret. (You cannot do SCRAM
     298              :              * authentication with an MD5 hash.)
     299              :              */
     300            2 :             state->logdetail = psprintf(_("User \"%s\" does not have a valid SCRAM secret."),
     301            1 :                                         state->port->user_name);
     302            1 :             got_secret = false;
     303              :         }
     304              :     }
     305              :     else
     306              :     {
     307              :         /*
     308              :          * The caller requested us to perform a dummy authentication.  This is
     309              :          * considered normal, since the caller requested it, so don't set log
     310              :          * detail.
     311              :          */
     312            1 :         got_secret = false;
     313              :     }
     314              : 
     315              :     /*
     316              :      * If the user did not have a valid SCRAM secret, we still go through the
     317              :      * motions with a mock one, and fail as if the client supplied an
     318              :      * incorrect password.  This is to avoid revealing information to an
     319              :      * attacker.
     320              :      */
     321           59 :     if (!got_secret)
     322              :     {
     323            2 :         mock_scram_secret(state->port->user_name, &state->hash_type,
     324              :                           &state->iterations, &state->key_length,
     325              :                           &state->salt,
     326            2 :                           state->StoredKey, state->ServerKey);
     327            2 :         state->doomed = true;
     328              :     }
     329              : 
     330           59 :     return state;
     331              : }
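
/*
 * Illustration only, not part of auth-scram.c: a rough standalone sketch of
 * splitting a stored secret into the four pieces that scram_init() obtains
 * from parse_scram_secret() above.  The textual layout
 * "SCRAM-SHA-256$<iterations>:<salt>$<StoredKey>:<ServerKey>" is an
 * assumption taken from the documented format of pg_authid.rolpassword; it
 * is not spelled out in this listing.  secret_fields, copy_field() and
 * split_scram_secret() are hypothetical names.  Unlike the real parser,
 * this sketch leaves every field base64-encoded and does no further
 * validation.
 */

```c
#include <assert.h>
#include <limits.h>
#include <stdlib.h>
#include <string.h>

#define FIELD_MAX 128           /* arbitrary cap chosen for this sketch */

typedef struct
{
    int         iterations;
    char        salt[FIELD_MAX];        /* still base64-encoded */
    char        stored_key[FIELD_MAX];  /* still base64-encoded */
    char        server_key[FIELD_MAX];  /* still base64-encoded */
} secret_fields;

/* Copy the non-empty span [start, end) into dst; returns 1 on success. */
static int
copy_field(char *dst, const char *start, const char *end)
{
    size_t      len;

    assert(start != NULL);
    if (end == NULL || end <= start)
        return 0;
    len = (size_t) (end - start);
    if (len >= FIELD_MAX)
        return 0;
    memcpy(dst, start, len);
    dst[len] = '\0';
    return 1;
}

/* Split the secret into its four fields; returns 1 on success, else 0. */
static int
split_scram_secret(const char *secret, secret_fields *out)
{
    static const char prefix[] = "SCRAM-SHA-256$";
    const char *p;
    const char *sep;
    char       *endptr;
    long        iters;

    if (strncmp(secret, prefix, sizeof(prefix) - 1) != 0)
        return 0;
    p = secret + sizeof(prefix) - 1;

    /* <iterations>: */
    iters = strtol(p, &endptr, 10);
    if (endptr == p || *endptr != ':' || iters <= 0 || iters > INT_MAX)
        return 0;
    out->iterations = (int) iters;
    p = endptr + 1;

    /* <salt>$ */
    sep = strchr(p, '$');
    if (!copy_field(out->salt, p, sep))
        return 0;
    p = sep + 1;

    /* <StoredKey>: */
    sep = strchr(p, ':');
    if (!copy_field(out->stored_key, p, sep))
        return 0;
    p = sep + 1;

    /* <ServerKey> runs to the end of the string */
    return copy_field(out->server_key, p, p + strlen(p));
}
```

/*
 * A caller in the position of scram_init() would treat a 0 return the same
 * way the code above treats a failed parse: log it, then fall back to a
 * mock secret and mark the exchange as doomed.
 */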
     332              : 
     333              : /*
     334              :  * Continue a SCRAM authentication exchange.
     335              :  *
     336              :  * 'input' is the SCRAM payload sent by the client.  On the first call,
     337              :  * 'input' contains the "Initial Client Response" that the client sent as
     338              :  * part of the SASLInitialResponse message, or NULL if no Initial Client
     339              :  * Response was given.  (The SASL specification distinguishes between an
     340              :  * empty response and a non-existent one.)  On subsequent calls, 'input'
     341              :  * cannot be NULL.  For convenience in this function, the caller must
     342              :  * ensure that there is a null terminator at input[inputlen].
     343              :  *
     344              :  * The next message to send to client is saved in 'output', for a length
     345              :  * of 'outputlen'.  In the case of an error, optionally store a palloc'd
     346              :  * string at *logdetail that will be sent to the postmaster log (but not
     347              :  * the client).
     348              :  */
     349              : static int
     350          118 : scram_exchange(void *opaq, const char *input, int inputlen,
     351              :                char **output, int *outputlen, const char **logdetail)
     352              : {
     353          118 :     scram_state *state = (scram_state *) opaq;
     354              :     int         result;
     355              : 
     356          118 :     *output = NULL;
     357              : 
     358              :     /*
     359              :      * If the client didn't include an "Initial Client Response" in the
     360              :      * SASLInitialResponse message, send an empty challenge, to which the
     361              :      * client will respond with the same data that usually comes in the
     362              :      * Initial Client Response.
     363              :      */
     364          118 :     if (input == NULL)
     365              :     {
     366              :         Assert(state->state == SCRAM_AUTH_INIT);
     367              : 
     368            0 :         *output = pstrdup("");
     369            0 :         *outputlen = 0;
     370            0 :         return PG_SASL_EXCHANGE_CONTINUE;
     371              :     }
     372              : 
     373              :     /*
     374              :      * Check that the input length agrees with the string length of the input.
     375              :      * We can ignore inputlen after this.
     376              :      */
     377          118 :     if (inputlen == 0)
     378            0 :         ereport(ERROR,
     379              :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
     380              :                  errmsg("malformed SCRAM message"),
     381              :                  errdetail("The message is empty.")));
     382          118 :     if (inputlen != strlen(input))
     383            0 :         ereport(ERROR,
     384              :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
     385              :                  errmsg("malformed SCRAM message"),
     386              :                  errdetail("Message length does not match input length.")));
     387              : 
     388          118 :     switch (state->state)
     389              :     {
     390           59 :         case SCRAM_AUTH_INIT:
     391              : 
     392              :             /*
     393              :              * Initialization phase.  Receive the first message from client
     394              :              * and be sure that it parsed correctly.  Then send the challenge
     395              :              * to the client.
     396              :              */
     397           59 :             read_client_first_message(state, input);
     398              : 
     399              :             /* prepare message to send challenge */
     400           59 :             *output = build_server_first_message(state);
     401              : 
     402           59 :             state->state = SCRAM_AUTH_SALT_SENT;
     403           59 :             result = PG_SASL_EXCHANGE_CONTINUE;
     404           59 :             break;
     405              : 
     406           59 :         case SCRAM_AUTH_SALT_SENT:
     407              : 
     408              :             /*
     409              :              * Final phase for the server.  Receive the response to the
     410              :              * challenge previously sent, verify, and let the client know that
     411              :              * everything went well (or not).
     412              :              */
     413           59 :             read_client_final_message(state, input);
     414              : 
     415           59 :             if (!verify_final_nonce(state))
     416            0 :                 ereport(ERROR,
     417              :                         (errcode(ERRCODE_PROTOCOL_VIOLATION),
     418              :                          errmsg("invalid SCRAM response"),
     419              :                          errdetail("Nonce does not match.")));
     420              : 
     421              :             /*
     422              :              * Now check the final nonce and the client proof.
     423              :              *
     424              :              * If we performed a "mock" authentication that we knew would fail
     425              :              * from the get go, this is where we fail.
     426              :              *
     427              :              * The SCRAM specification includes an error code,
     428              :              * "invalid-proof", for authentication failure, but it also allows
     429              :              * erroring out in an application-specific way.  We choose to do
     430              :              * the latter, so that the error message for invalid password is
     431              :              * the same for all authentication methods.  The caller will call
     432              :              * ereport(), when we return PG_SASL_EXCHANGE_FAILURE with no
     433              :              * output.
     434              :              *
     435              :              * NB: the order of these checks is intentional.  We calculate the
     436              :              * client proof even in a mock authentication, even though it's
     437              :              * bound to fail, to thwart timing attacks to determine if a role
     438              :              * with the given name exists or not.
     439              :              */
     440           59 :             if (!verify_client_proof(state) || state->doomed)
     441              :             {
     442            6 :                 result = PG_SASL_EXCHANGE_FAILURE;
     443            6 :                 break;
     444              :             }
     445              : 
     446              :             /* Build final message for client */
     447           53 :             *output = build_server_final_message(state);
     448              : 
     449              :             /* Success! */
     450           53 :             result = PG_SASL_EXCHANGE_SUCCESS;
     451           53 :             state->state = SCRAM_AUTH_FINISHED;
     452           53 :             break;
     453              : 
     454            0 :         default:
     455            0 :             elog(ERROR, "invalid SCRAM exchange state");
     456              :             result = PG_SASL_EXCHANGE_FAILURE;
     457              :     }
     458              : 
     459          118 :     if (result == PG_SASL_EXCHANGE_FAILURE && state->logdetail && logdetail)
     460            1 :         *logdetail = state->logdetail;
     461              : 
     462          118 :     if (*output)
     463          112 :         *outputlen = strlen(*output);
     464              : 
     465          118 :     if (result == PG_SASL_EXCHANGE_SUCCESS && state->state == SCRAM_AUTH_FINISHED)
     466              :     {
     467           53 :         memcpy(MyProcPort->scram_ClientKey, state->ClientKey, sizeof(MyProcPort->scram_ClientKey));
     468           53 :         memcpy(MyProcPort->scram_ServerKey, state->ServerKey, sizeof(MyProcPort->scram_ServerKey));
     469           53 :         MyProcPort->has_scram_keys = true;
     470              :     }
     471              : 
     472          118 :     return result;
     473              : }
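
/*
 * Illustration only, not part of auth-scram.c: a small standalone sketch of
 * why the order of operands in
 * "!verify_client_proof(state) || state->doomed" matters.  check_proof() is
 * a hypothetical stand-in that merely counts its calls; it does none of the
 * real HMAC work.  The point is only that C evaluates the left operand of
 * || first, so the proof check runs even for an exchange that is already
 * doomed.
 */

```c
#include <assert.h>
#include <stdbool.h>

static int  proof_checks = 0;   /* how many times the stand-in check ran */

/* Stand-in for verify_client_proof(): records the call, returns the verdict. */
static bool
check_proof(bool proof_ok)
{
    proof_checks++;
    return proof_ok;
}

/*
 * Same shape as the condition in scram_exchange(): the proof check is the
 * left operand of ||, so it is evaluated before 'doomed' is consulted.
 */
static bool
exchange_fails(bool proof_ok, bool doomed)
{
    return !check_proof(proof_ok) || doomed;
}

/* Usage: a doomed exchange still pays for the proof check. */
static void
demo_check_order(void)
{
    bool        doomed_fails;
    bool        good_fails;
    bool        bad_fails;

    proof_checks = 0;

    doomed_fails = exchange_fails(true, true);
    assert(doomed_fails);           /* doomed: fails regardless of proof */
    assert(proof_checks == 1);      /* ...but the check still ran */

    good_fails = exchange_fails(true, false);
    bad_fails = exchange_fails(false, false);
    assert(!good_fails);            /* existing role, correct proof */
    assert(bad_fails);              /* existing role, wrong proof */
    assert(proof_checks == 3);
}
```

/*
 * Had the operands been swapped (doomed first), short-circuit evaluation
 * would skip the proof check for mock exchanges, and a nonexistent role
 * could answer measurably faster than an existing one.
 */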
     474              : 
     475              : /*
     476              :  * Construct a SCRAM secret, for storing in pg_authid.rolpassword.
     477              :  *
     478              :  * The result is palloc'd, so caller is responsible for freeing it.
     479              :  */
     480              : char *
     481           56 : pg_be_scram_build_secret(const char *password)
     482              : {
     483              :     char       *prep_password;
     484              :     pg_saslprep_rc rc;
     485              :     uint8       saltbuf[SCRAM_DEFAULT_SALT_LEN];
     486              :     char       *result;
     487           56 :     const char *errstr = NULL;
     488              : 
     489              :     /*
     490              :      * Normalize the password with SASLprep.  If that doesn't work, because
     491              :      * the password isn't valid UTF-8 or contains prohibited characters, just
     492              :      * proceed with the original password.  (See comments at top of file.)
     493              :      */
     494           56 :     rc = pg_saslprep(password, &prep_password);
     495           56 :     if (rc == SASLPREP_SUCCESS)
     496           55 :         password = (const char *) prep_password;
     497              : 
     498              :     /* Generate random salt */
     499           56 :     if (!pg_strong_random(saltbuf, SCRAM_DEFAULT_SALT_LEN))
     500            0 :         ereport(ERROR,
     501              :                 (errcode(ERRCODE_INTERNAL_ERROR),
     502              :                  errmsg("could not generate random salt")));
     503              : 
     504           56 :     result = scram_build_secret(PG_SHA256, SCRAM_SHA_256_KEY_LEN,
     505              :                                 saltbuf, SCRAM_DEFAULT_SALT_LEN,
     506              :                                 scram_sha_256_iterations, password,
     507              :                                 &errstr);
     508              : 
     509           56 :     if (prep_password)
     510           55 :         pfree(prep_password);
     511              : 
     512           56 :     return result;
     513              : }
     514              : 
     515              : /*
     516              :  * Verify a plaintext password against a SCRAM secret.  This is used when
     517              :  * performing plaintext password authentication for a user that has a SCRAM
     518              :  * secret stored in pg_authid.
     519              :  */
     520              : bool
     521           27 : scram_verify_plain_password(const char *username, const char *password,
     522              :                             const char *secret)
     523              : {
     524              :     char       *encoded_salt;
     525              :     uint8      *salt;
     526              :     int         saltlen;
     527              :     int         iterations;
     528           27 :     int         key_length = 0;
     529              :     pg_cryptohash_type hash_type;
     530              :     uint8       salted_password[SCRAM_MAX_KEY_LEN];
     531              :     uint8       stored_key[SCRAM_MAX_KEY_LEN];
     532              :     uint8       server_key[SCRAM_MAX_KEY_LEN];
     533              :     uint8       computed_key[SCRAM_MAX_KEY_LEN];
     534              :     char       *prep_password;
     535              :     pg_saslprep_rc rc;
     536           27 :     const char *errstr = NULL;
     537              : 
     538           27 :     if (!parse_scram_secret(secret, &iterations, &hash_type, &key_length,
     539              :                             &encoded_salt, stored_key, server_key))
     540              :     {
     541              :         /*
     542              :          * The password looked like a SCRAM secret, but could not be parsed.
     543              :          */
     544            0 :         ereport(LOG,
     545              :                 (errmsg("invalid SCRAM secret for user \"%s\"", username)));
     546            0 :         return false;
     547              :     }
     548              : 
     549           27 :     saltlen = pg_b64_dec_len(strlen(encoded_salt));
     550           27 :     salt = palloc(saltlen);
     551           27 :     saltlen = pg_b64_decode(encoded_salt, strlen(encoded_salt), salt,
     552              :                             saltlen);
     553           27 :     if (saltlen < 0)
     554              :     {
     555            0 :         ereport(LOG,
     556              :                 (errmsg("invalid SCRAM secret for user \"%s\"", username)));
     557            0 :         return false;
     558              :     }
     559              : 
     560              :     /* Normalize the password */
     561           27 :     rc = pg_saslprep(password, &prep_password);
     562           27 :     if (rc == SASLPREP_SUCCESS)
     563           27 :         password = prep_password;
     564              : 
     565              :     /* Compute Server Key based on the user-supplied plaintext password */
     566           27 :     if (scram_SaltedPassword(password, hash_type, key_length,
     567              :                              salt, saltlen, iterations,
     568           27 :                              salted_password, &errstr) < 0 ||
     569           27 :         scram_ServerKey(salted_password, hash_type, key_length,
     570              :                         computed_key, &errstr) < 0)
     571              :     {
     572            0 :         elog(ERROR, "could not compute server key: %s", errstr);
     573              :     }
     574              : 
     575           27 :     if (prep_password)
     576           27 :         pfree(prep_password);
     577              : 
     578              :     /*
     579              :      * Compare the secret's Server Key with the one computed from the
     580              :      * user-supplied password.
     581              :      */
     582           27 :     return memcmp(computed_key, server_key, key_length) == 0;
     583              : }
     584              : 
     585              : 
     586              : /*
     587              :  * Parse and validate format of given SCRAM secret.
     588              :  *
     589              :  * On success, the iteration count, salt, stored key, and server key are
     590              :  * extracted from the secret, and returned to the caller.  For 'stored_key'
     591              :  * and 'server_key', the caller must pass pre-allocated buffers of size
     592              :  * SCRAM_MAX_KEY_LEN.  Salt is returned as a base64-encoded, null-terminated
     593              :  * string.  The buffer for the salt is palloc'd by this function.
     594              :  *
     595              :  * Returns true if the SCRAM secret has been parsed, and false otherwise.
     596              :  */
     597              : bool
     598          453 : parse_scram_secret(const char *secret, int *iterations,
     599              :                    pg_cryptohash_type *hash_type, int *key_length,
     600              :                    char **salt, uint8 *stored_key, uint8 *server_key)
     601              : {
     602              :     char       *v;
     603              :     char       *p;
     604              :     char       *scheme_str;
     605              :     char       *salt_str;
     606              :     char       *iterations_str;
     607              :     char       *storedkey_str;
     608              :     char       *serverkey_str;
     609              :     int         decoded_len;
     610              :     uint8      *decoded_salt_buf;
     611              :     uint8      *decoded_stored_buf;
     612              :     uint8      *decoded_server_buf;
     613              : 
     614              :     /*
     615              :      * The secret is of form:
     616              :      *
     617              :      * SCRAM-SHA-256$<iterations>:<salt>$<storedkey>:<serverkey>
     618              :      */
     619          453 :     v = pstrdup(secret);
     620          453 :     scheme_str = strsep(&v, "$");
     621          453 :     if (v == NULL)
     622          121 :         goto invalid_secret;
     623          332 :     iterations_str = strsep(&v, ":");
     624          332 :     if (v == NULL)
     625            6 :         goto invalid_secret;
     626          326 :     salt_str = strsep(&v, "$");
     627          326 :     if (v == NULL)
     628            0 :         goto invalid_secret;
     629          326 :     storedkey_str = strsep(&v, ":");
     630          326 :     if (v == NULL)
     631            0 :         goto invalid_secret;
     632          326 :     serverkey_str = v;
     633              : 
     634              :     /* Parse the fields */
     635          326 :     if (strcmp(scheme_str, "SCRAM-SHA-256") != 0)
     636            0 :         goto invalid_secret;
     637          326 :     *hash_type = PG_SHA256;
     638          326 :     *key_length = SCRAM_SHA_256_KEY_LEN;
     639              : 
     640          326 :     errno = 0;
     641          326 :     *iterations = strtol(iterations_str, &p, 10);
     642          326 :     if (*p || errno != 0)
     643            0 :         goto invalid_secret;
     644              : 
     645              :     /*
     646              :      * Verify that the salt is in Base64-encoded format, by decoding it,
     647              :      * although we return the encoded version to the caller.
     648              :      */
     649          326 :     decoded_len = pg_b64_dec_len(strlen(salt_str));
     650          326 :     decoded_salt_buf = palloc(decoded_len);
     651          326 :     decoded_len = pg_b64_decode(salt_str, strlen(salt_str),
     652              :                                 decoded_salt_buf, decoded_len);
     653          326 :     if (decoded_len < 0)
     654            0 :         goto invalid_secret;
     655          326 :     *salt = pstrdup(salt_str);
     656              : 
     657              :     /*
     658              :      * Decode StoredKey and ServerKey.
     659              :      */
     660          326 :     decoded_len = pg_b64_dec_len(strlen(storedkey_str));
     661          326 :     decoded_stored_buf = palloc(decoded_len);
     662          326 :     decoded_len = pg_b64_decode(storedkey_str, strlen(storedkey_str),
     663              :                                 decoded_stored_buf, decoded_len);
     664          326 :     if (decoded_len != *key_length)
     665            6 :         goto invalid_secret;
     666          320 :     memcpy(stored_key, decoded_stored_buf, *key_length);
     667              : 
     668          320 :     decoded_len = pg_b64_dec_len(strlen(serverkey_str));
     669          320 :     decoded_server_buf = palloc(decoded_len);
     670          320 :     decoded_len = pg_b64_decode(serverkey_str, strlen(serverkey_str),
     671              :                                 decoded_server_buf, decoded_len);
     672          320 :     if (decoded_len != *key_length)
     673            6 :         goto invalid_secret;
     674          314 :     memcpy(server_key, decoded_server_buf, *key_length);
     675              : 
     676          314 :     return true;
     677              : 
     678          139 : invalid_secret:
     679          139 :     *salt = NULL;
     680          139 :     return false;
     681              : }
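
The field splitting performed by parse_scram_secret() can be sketched outside the server as follows. This is only an illustration under stated assumptions: demo_scram_fields, demo_cut() and demo_split_scram_secret() are hypothetical names that do not exist in PostgreSQL, the sketch uses strchr() in place of strsep(), it leaves the salt and keys base64-encoded rather than decoding them, and it additionally rejects an empty iteration field.

```c
#include <assert.h>
#include <errno.h>
#include <stdlib.h>
#include <string.h>

/* Hypothetical holder for the parsed fields; not a PostgreSQL type. */
typedef struct
{
    long        iterations;
    char       *salt;           /* still base64-encoded */
    char       *stored_key;     /* still base64-encoded */
    char       *server_key;     /* still base64-encoded */
} demo_scram_fields;

/*
 * Cut "buf" at the first occurrence of "delim" and return the remainder,
 * or NULL if the delimiter is missing.
 */
static char *
demo_cut(char *buf, char delim)
{
    char       *sep = strchr(buf, delim);

    if (sep == NULL)
        return NULL;
    *sep = '\0';
    return sep + 1;
}

/*
 * Split SCRAM-SHA-256$<iterations>:<salt>$<storedkey>:<serverkey> in place.
 * "buf" must be writable.  Returns 1 on success, 0 if malformed.
 */
static int
demo_split_scram_secret(char *buf, demo_scram_fields *out)
{
    char       *iter_str;
    char       *endp;

    assert(buf != NULL && out != NULL);

    if ((iter_str = demo_cut(buf, '$')) == NULL ||
        (out->salt = demo_cut(iter_str, ':')) == NULL ||
        (out->stored_key = demo_cut(out->salt, '$')) == NULL ||
        (out->server_key = demo_cut(out->stored_key, ':')) == NULL)
        return 0;

    /* After the first cut, "buf" holds only the scheme name. */
    if (strcmp(buf, "SCRAM-SHA-256") != 0)
        return 0;

    /* Sketch-only choice: an empty iteration field is also rejected. */
    errno = 0;
    out->iterations = strtol(iter_str, &endp, 10);
    if (endp == iter_str || *endp != '\0' || errno != 0)
        return 0;

    return 1;
}
```

A caller would pass a writable copy of the secret, much as the function above works on a pstrdup'd copy, and would still need to base64-decode the keys and check their lengths.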
     682              : 
     683              : /*
     684              :  * Generate plausible SCRAM secret parameters for mock authentication.
     685              :  *
     686              :  * In a normal authentication, these are extracted from the secret
     687              :  * stored in the server.  This function generates values that look
     688              :  * realistic, for when there is no stored secret, using SCRAM-SHA-256.
     689              :  *
     690              :  * Like in parse_scram_secret(), for 'stored_key' and 'server_key', the
     691              :  * caller must pass pre-allocated buffers of size SCRAM_MAX_KEY_LEN, and
     692              :  * the buffer for the salt is palloc'd by this function.
     693              :  */
     694              : static void
     695            2 : mock_scram_secret(const char *username, pg_cryptohash_type *hash_type,
     696              :                   int *iterations, int *key_length, char **salt,
     697              :                   uint8 *stored_key, uint8 *server_key)
     698              : {
     699              :     uint8      *raw_salt;
     700              :     char       *encoded_salt;
     701              :     int         encoded_len;
     702              : 
     703              :     /* Enforce the use of SHA-256, which would be realistic enough */
     704            2 :     *hash_type = PG_SHA256;
     705            2 :     *key_length = SCRAM_SHA_256_KEY_LEN;
     706              : 
     707              :     /*
     708              :      * Generate deterministic salt.
     709              :      *
     710              :      * Note that we cannot reveal any information to an attacker here so the
     711              :      * error messages need to remain generic.  This should never fail anyway
     712              :      * as the salt generated for mock authentication uses the cluster's nonce
     713              :      * value.
     714              :      */
     715            2 :     raw_salt = scram_mock_salt(username, *hash_type, *key_length);
     716            2 :     if (raw_salt == NULL)
     717            0 :         elog(ERROR, "could not encode salt");
     718              : 
     719            2 :     encoded_len = pg_b64_enc_len(SCRAM_DEFAULT_SALT_LEN);
     720              :     /* don't forget the zero-terminator */
     721            2 :     encoded_salt = (char *) palloc(encoded_len + 1);
     722            2 :     encoded_len = pg_b64_encode(raw_salt, SCRAM_DEFAULT_SALT_LEN, encoded_salt,
     723              :                                 encoded_len);
     724              : 
     725            2 :     if (encoded_len < 0)
     726            0 :         elog(ERROR, "could not encode salt");
     727            2 :     encoded_salt[encoded_len] = '\0';
     728              : 
     729            2 :     *salt = encoded_salt;
     730            2 :     *iterations = SCRAM_SHA_256_DEFAULT_ITERATIONS;
     731              : 
     732              :     /* StoredKey and ServerKey are not used in a doomed authentication */
     733            2 :     memset(stored_key, 0, SCRAM_MAX_KEY_LEN);
     734            2 :     memset(server_key, 0, SCRAM_MAX_KEY_LEN);
     735            2 : }
     736              : 
     737              : /*
     738              :  * Read the value in a given SCRAM exchange message for the given attribute.
     739              :  */
     740              : static char *
     741          240 : read_attr_value(char **input, char attr)
     742              : {
     743          240 :     char       *begin = *input;
     744              :     char       *end;
     745              : 
     746          240 :     if (*begin != attr)
     747            0 :         ereport(ERROR,
     748              :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
     749              :                  errmsg("malformed SCRAM message"),
     750              :                  errdetail("Expected attribute \"%c\" but found \"%s\".",
     751              :                            attr, sanitize_char(*begin))));
     752          240 :     begin++;
     753              : 
     754          240 :     if (*begin != '=')
     755            0 :         ereport(ERROR,
     756              :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
     757              :                  errmsg("malformed SCRAM message"),
     758              :                  errdetail("Expected character \"=\" for attribute \"%c\".", attr)));
     759          240 :     begin++;
     760              : 
     761          240 :     end = begin;
     762         5092 :     while (*end && *end != ',')
     763         4852 :         end++;
     764              : 
     765          240 :     if (*end)
     766              :     {
     767          181 :         *end = '\0';
     768          181 :         *input = end + 1;
     769              :     }
     770              :     else
     771           59 :         *input = end;
     772              : 
     773          240 :     return begin;
     774              : }
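
To make the in-place tokenizing of read_attr_value() easier to follow, here is a minimal stand-alone sketch. It assumes a hypothetical helper, demo_read_attr(), which is not part of PostgreSQL and which returns NULL on a malformed message instead of calling ereport(ERROR).

```c
#include <assert.h>
#include <stddef.h>
#include <string.h>

/*
 * Hypothetical stand-alone variant of an attribute reader: expects
 * "<attr>=<value>" at *input, NUL-terminates the value in place, and
 * advances *input past the following comma, if there is one.
 */
static char *
demo_read_attr(char **input, char attr)
{
    char       *begin;
    char       *end;

    assert(input != NULL && *input != NULL);
    begin = *input;

    if (*begin != attr)
        return NULL;            /* wrong attribute, or end of string */
    begin++;
    if (*begin != '=')
        return NULL;
    begin++;

    /* The value runs up to the next comma or the end of the message. */
    for (end = begin; *end != '\0' && *end != ','; end++)
        ;

    if (*end != '\0')
    {
        *end = '\0';
        *input = end + 1;
    }
    else
        *input = end;

    return begin;
}
```

Applied to the writable string "n=,r=fyko+d2lbbFgONRv9qkxdawL", a first call with 'n' yields an empty username, and a second call with 'r' yields the nonce and leaves the input pointer at the terminating NUL.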
     775              : 
     776              : static bool
     777           59 : is_scram_printable(char *p)
     778              : {
     779              :     /*------
     780              :      * Printable characters, as defined by SCRAM spec: (RFC 5802)
     781              :      *
     782              :      *  printable       = %x21-2B / %x2D-7E
     783              :      *                    ;; Printable ASCII except ",".
     784              :      *                    ;; Note that any "printable" is also
     785              :      *                    ;; a valid "value".
     786              :      *------
     787              :      */
     788         1475 :     for (; *p; p++)
     789              :     {
     790         1416 :         if (*p < 0x21 || *p > 0x7E || *p == 0x2C /* comma */ )
     791            0 :             return false;
     792              :     }
     793           59 :     return true;
     794              : }
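
The "printable" rule quoted from RFC 5802 can be exercised in isolation with a small sketch such as the one below. demo_is_scram_printable() is a hypothetical name; unlike the function above it casts each byte to unsigned char, so bytes with the high bit set are rejected regardless of whether plain char is signed.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>

/*
 * Return true if every byte of "s" is in %x21-2B / %x2D-7E, that is,
 * printable ASCII other than space and ",".
 */
static bool
demo_is_scram_printable(const char *s)
{
    assert(s != NULL);

    for (; *s != '\0'; s++)
    {
        unsigned char c = (unsigned char) *s;

        if (c < 0x21 || c > 0x7E || c == ',')
            return false;
    }
    return true;
}
```

Under this rule a nonce such as "fyko+d2lbbFgONRv9qkxdawL" passes, while any value containing a comma, a space, or a control character is rejected.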
     795              : 
     796              : /*
     797              :  * Convert an arbitrary byte to printable form.  For error messages.
     798              :  *
     799              :  * If it's a printable ASCII character, print it as a single character.
     800              :  * Otherwise, print it in hex.
     801              :  *
     802              :  * The returned pointer points to a static buffer.
     803              :  */
     804              : static char *
     805            0 : sanitize_char(char c)
     806              : {
     807              :     static char buf[5];
     808              : 
     809            0 :     if (c >= 0x21 && c <= 0x7E)
     810            0 :         snprintf(buf, sizeof(buf), "'%c'", c);
     811              :     else
     812            0 :         snprintf(buf, sizeof(buf), "0x%02x", (unsigned char) c);
     813            0 :     return buf;
     814              : }
     815              : 
     816              : /*
     817              :  * Convert an arbitrary string to printable form, for error messages.
     818              :  *
     819              :  * Anything that's not a printable ASCII character is replaced with
     820              :  * '?', and the string is truncated at 30 characters.
     821              :  *
     822              :  * The returned pointer points to a static buffer.
     823              :  */
     824              : static char *
     825            0 : sanitize_str(const char *s)
     826              : {
     827              :     static char buf[30 + 1];
     828              :     int         i;
     829              : 
     830            0 :     for (i = 0; i < sizeof(buf) - 1; i++)
     831              :     {
     832            0 :         char        c = s[i];
     833              : 
     834            0 :         if (c == '\0')
     835            0 :             break;
     836              : 
     837            0 :         if (c >= 0x21 && c <= 0x7E)
     838            0 :             buf[i] = c;
     839              :         else
     840            0 :             buf[i] = '?';
     841              :     }
     842            0 :     buf[i] = '\0';
     843            0 :     return buf;
     844              : }
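
Both sanitizing helpers return a pointer to a static buffer, so each call overwrites the result of the previous one. As a sketch of the same replacement rule without that shared state, the variant below writes into a buffer supplied by the caller. demo_sanitize_str() and its signature are assumptions made for illustration, not an existing PostgreSQL function.

```c
#include <assert.h>
#include <stddef.h>
#include <string.h>

/*
 * Hypothetical caller-buffer variant: copy at most dstlen - 1 bytes of
 * "s" into "dst", replacing anything outside 0x21..0x7E with '?', and
 * always NUL-terminate the result.
 */
static char *
demo_sanitize_str(const char *s, char *dst, size_t dstlen)
{
    size_t      i;

    assert(s != NULL && dst != NULL && dstlen > 0);

    for (i = 0; i < dstlen - 1 && s[i] != '\0'; i++)
    {
        unsigned char c = (unsigned char) s[i];

        dst[i] = (c >= 0x21 && c <= 0x7E) ? (char) c : '?';
    }
    dst[i] = '\0';
    return dst;
}
```

With a 31-byte destination this matches the 30-character truncation described above. Note that a space (0x20) falls outside the accepted range and is therefore replaced as well.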
     845              : 
     846              : /*
     847              :  * Read the next attribute and value in a SCRAM exchange message.
     848              :  *
     849              :  * The attribute character is set in *attr_p, the attribute value is the
     850              :  * return value.
     851              :  */
     852              : static char *
     853           59 : read_any_attr(char **input, char *attr_p)
     854              : {
     855           59 :     char       *begin = *input;
     856              :     char       *end;
     857           59 :     char        attr = *begin;
     858              : 
     859           59 :     if (attr == '\0')
     860            0 :         ereport(ERROR,
     861              :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
     862              :                  errmsg("malformed SCRAM message"),
     863              :                  errdetail("Attribute expected, but found end of string.")));
     864              : 
     865              :     /*------
     866              :      * attr-val        = ALPHA "=" value
     867              :      *                   ;; Generic syntax of any attribute sent
     868              :      *                   ;; by server or client
     869              :      *------
     870              :      */
     871           59 :     if (!((attr >= 'A' && attr <= 'Z') ||
     872           59 :           (attr >= 'a' && attr <= 'z')))
     873            0 :         ereport(ERROR,
     874              :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
     875              :                  errmsg("malformed SCRAM message"),
     876              :                  errdetail("Attribute expected, but found invalid character \"%s\".",
     877              :                            sanitize_char(attr))));
     878           59 :     if (attr_p)
     879           59 :         *attr_p = attr;
     880           59 :     begin++;
     881              : 
     882           59 :     if (*begin != '=')
     883            0 :         ereport(ERROR,
     884              :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
     885              :                  errmsg("malformed SCRAM message"),
     886              :                  errdetail("Expected character \"=\" for attribute \"%c\".", attr)));
     887           59 :     begin++;
     888              : 
     889           59 :     end = begin;
     890         2655 :     while (*end && *end != ',')
     891         2596 :         end++;
     892              : 
     893           59 :     if (*end)
     894              :     {
     895            0 :         *end = '\0';
     896            0 :         *input = end + 1;
     897              :     }
     898              :     else
     899           59 :         *input = end;
     900              : 
     901           59 :     return begin;
     902              : }
     903              : 
     904              : /*
     905              :  * Read and parse the first message from the client in a SCRAM
     906              :  * authentication exchange.
     907              :  *
     908              :  * At this stage, any errors will be reported directly with ereport(ERROR).
     909              :  */
     910              : static void
     911           59 : read_client_first_message(scram_state *state, const char *input)
     912              : {
     913           59 :     char       *p = pstrdup(input);
     914              :     char       *channel_binding_type;
     915              : 
     916              : 
     917              :     /*------
     918              :      * The syntax for the client-first-message is: (RFC 5802)
     919              :      *
     920              :      * saslname        = 1*(value-safe-char / "=2C" / "=3D")
     921              :      *                   ;; Conforms to <value>.
     922              :      *
     923              :      * authzid         = "a=" saslname
     924              :      *                   ;; Protocol specific.
     925              :      *
     926              :      * cb-name         = 1*(ALPHA / DIGIT / "." / "-")
     927              :      *                    ;; See RFC 5056, Section 7.
     928              :      *                    ;; E.g., "tls-server-end-point" or
     929              :      *                    ;; "tls-unique".
     930              :      *
     931              :      * gs2-cbind-flag  = ("p=" cb-name) / "n" / "y"
     932              :      *                   ;; "n" -> client doesn't support channel binding.
     933              :      *                   ;; "y" -> client does support channel binding
     934              :      *                   ;;        but thinks the server does not.
     935              :      *                   ;; "p" -> client requires channel binding.
     936              :      *                   ;; The selected channel binding follows "p=".
     937              :      *
     938              :      * gs2-header      = gs2-cbind-flag "," [ authzid ] ","
     939              :      *                   ;; GS2 header for SCRAM
     940              :      *                   ;; (the actual GS2 header includes an optional
     941              :      *                   ;; flag to indicate that the GSS mechanism is not
     942              :      *                   ;; "standard", but since SCRAM is "standard", we
     943              :      *                   ;; don't include that flag).
     944              :      *
     945              :      * username        = "n=" saslname
     946              :      *                   ;; Usernames are prepared using SASLprep.
     947              :      *
     948              :      * reserved-mext  = "m=" 1*(value-char)
     949              :      *                   ;; Reserved for signaling mandatory extensions.
     950              :      *                   ;; The exact syntax will be defined in
     951              :      *                   ;; the future.
     952              :      *
     953              :      * nonce           = "r=" c-nonce [s-nonce]
     954              :      *                   ;; Second part provided by server.
     955              :      *
     956              :      * c-nonce         = printable
     957              :      *
     958              :      * client-first-message-bare =
     959              :      *                   [reserved-mext ","]
     960              :      *                   username "," nonce ["," extensions]
     961              :      *
     962              :      * client-first-message =
     963              :      *                   gs2-header client-first-message-bare
     964              :      *
     965              :      * For example:
     966              :      * n,,n=user,r=fyko+d2lbbFgONRv9qkxdawL
     967              :      *
     968              :      * The "n,," in the beginning means that the client doesn't support
     969              :      * channel binding, and no authzid is given.  "n=user" is the username.
     970              :      * However, in PostgreSQL the username is sent in the startup packet, and
     971              :      * the username in the SCRAM exchange is ignored.  libpq always sends it
     972              :      * as an empty string.  The last part, "r=fyko+d2lbbFgONRv9qkxdawL" is
     973              :      * the client nonce.
     974              :      *------
     975              :      */
     976              : 
     977              :     /*
     978              :      * Read gs2-cbind-flag.  (For details see also RFC 5802 Section 6 "Channel
     979              :      * Binding".)
     980              :      */
     981           59 :     state->cbind_flag = *p;
     982           59 :     switch (*p)
     983              :     {
     984           55 :         case 'n':
     985              : 
     986              :             /*
     987              :              * The client does not support channel binding or has simply
     988              :              * decided to not use it.  In that case just let it go.
     989              :              */
     990           55 :             if (state->channel_binding_in_use)
     991            0 :                 ereport(ERROR,
     992              :                         (errcode(ERRCODE_PROTOCOL_VIOLATION),
     993              :                          errmsg("malformed SCRAM message"),
     994              :                          errdetail("The client selected SCRAM-SHA-256-PLUS, but the SCRAM message does not include channel binding data.")));
     995              : 
     996           55 :             p++;
     997           55 :             if (*p != ',')
     998            0 :                 ereport(ERROR,
     999              :                         (errcode(ERRCODE_PROTOCOL_VIOLATION),
    1000              :                          errmsg("malformed SCRAM message"),
    1001              :                          errdetail("Comma expected, but found character \"%s\".",
    1002              :                                    sanitize_char(*p))));
    1003           55 :             p++;
    1004           55 :             break;
    1005            0 :         case 'y':
    1006              : 
    1007              :             /*
    1008              :              * The client supports channel binding and thinks that the server
    1009              :              * does not.  In this case, the server must fail authentication if
    1010              :              * it supports channel binding.
    1011              :              */
    1012            0 :             if (state->channel_binding_in_use)
    1013            0 :                 ereport(ERROR,
    1014              :                         (errcode(ERRCODE_PROTOCOL_VIOLATION),
    1015              :                          errmsg("malformed SCRAM message"),
    1016              :                          errdetail("The client selected SCRAM-SHA-256-PLUS, but the SCRAM message does not include channel binding data.")));
    1017              : 
    1018              : #ifdef USE_SSL
    1019            0 :             if (state->port->ssl_in_use)
    1020            0 :                 ereport(ERROR,
    1021              :                         (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
    1022              :                          errmsg("SCRAM channel binding negotiation error"),
    1023              :                          errdetail("The client supports SCRAM channel binding but thinks the server does not.  "
    1024              :                                    "However, this server does support channel binding.")));
    1025              : #endif
    1026            0 :             p++;
    1027            0 :             if (*p != ',')
    1028            0 :                 ereport(ERROR,
    1029              :                         (errcode(ERRCODE_PROTOCOL_VIOLATION),
    1030              :                          errmsg("malformed SCRAM message"),
    1031              :                          errdetail("Comma expected, but found character \"%s\".",
    1032              :                                    sanitize_char(*p))));
    1033            0 :             p++;
    1034            0 :             break;
    1035            4 :         case 'p':
    1036              : 
    1037              :             /*
    1038              :              * The client requires channel binding.  Channel binding type
    1039              :              * follows, e.g., "p=tls-server-end-point".
    1040              :              */
    1041            4 :             if (!state->channel_binding_in_use)
    1042            0 :                 ereport(ERROR,
    1043              :                         (errcode(ERRCODE_PROTOCOL_VIOLATION),
    1044              :                          errmsg("malformed SCRAM message"),
    1045              :                          errdetail("The client selected SCRAM-SHA-256 without channel binding, but the SCRAM message includes channel binding data.")));
    1046              : 
    1047            4 :             channel_binding_type = read_attr_value(&p, 'p');
    1048              : 
    1049              :             /*
    1050              :              * The only channel binding type we support is
    1051              :              * tls-server-end-point.
    1052              :              */
    1053            4 :             if (strcmp(channel_binding_type, "tls-server-end-point") != 0)
    1054            0 :                 ereport(ERROR,
    1055              :                         (errcode(ERRCODE_PROTOCOL_VIOLATION),
    1056              :                          errmsg("unsupported SCRAM channel-binding type \"%s\"",
    1057              :                                 sanitize_str(channel_binding_type))));
    1058            4 :             break;
    1059            0 :         default:
    1060            0 :             ereport(ERROR,
    1061              :                     (errcode(ERRCODE_PROTOCOL_VIOLATION),
    1062              :                      errmsg("malformed SCRAM message"),
    1063              :                      errdetail("Unexpected channel-binding flag \"%s\".",
    1064              :                                sanitize_char(*p))));
    1065              :     }
    1066              : 
    1067              :     /*
    1068              :      * Forbid optional authzid (authorization identity).  We don't support it.
    1069              :      */
    1070           59 :     if (*p == 'a')
    1071            0 :         ereport(ERROR,
    1072              :                 (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
    1073              :                  errmsg("client uses authorization identity, but it is not supported")));
    1074           59 :     if (*p != ',')
    1075            0 :         ereport(ERROR,
    1076              :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
    1077              :                  errmsg("malformed SCRAM message"),
    1078              :                  errdetail("Unexpected attribute \"%s\" in client-first-message.",
    1079              :                            sanitize_char(*p))));
    1080           59 :     p++;
    1081              : 
    1082           59 :     state->client_first_message_bare = pstrdup(p);
    1083              : 
    1084              :     /*
    1085              :      * Any mandatory extensions would go here.  We don't support any.
    1086              :      *
    1087              :      * RFC 5802 specifies error code "e=extensions-not-supported" for this,
    1088              :      * but it can only be sent in the server-final message.  We prefer to fail
    1089              :      * immediately (which the RFC also allows).
    1090              :      */
    1091           59 :     if (*p == 'm')
    1092            0 :         ereport(ERROR,
    1093              :                 (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
    1094              :                  errmsg("client requires an unsupported SCRAM extension")));
    1095              : 
    1096              :     /*
    1097              :      * Read username.  Note: this is ignored.  We use the username from the
    1098              :      * startup message instead.  It is still kept around if provided, as it
    1099              :      * is useful for debugging purposes.
    1100              :      */
    1101           59 :     state->client_username = read_attr_value(&p, 'n');
    1102              : 
    1103              :     /* read nonce and check that it is made of only printable characters */
    1104           59 :     state->client_nonce = read_attr_value(&p, 'r');
    1105           59 :     if (!is_scram_printable(state->client_nonce))
    1106            0 :         ereport(ERROR,
    1107              :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
    1108              :                  errmsg("non-printable characters in SCRAM nonce")));
    1109              : 
    1110              :     /*
    1111              :      * There can be any number of optional extensions after this.  We don't
    1112              :      * support any extensions, so ignore them.
    1113              :      */
    1114           59 :     while (*p != '\0')
    1115            0 :         read_any_attr(&p, NULL);
    1116              : 
    1117              :     /* success! */
    1118           59 : }
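
The printable-character check applied to the client nonce can be illustrated in isolation. This is a minimal sketch, not the is_scram_printable() implementation called above: nonce_is_printable is a hypothetical name, and the accepted range is assumed from the RFC 5802 definition of "printable" (%x21-2B / %x2D-7E, that is, printable ASCII without space or comma).

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>

/*
 * Hypothetical helper (not part of auth-scram.c): returns true if every
 * byte of s lies in 0x21..0x7E and is not a comma (0x2C), following the
 * "printable" rule of RFC 5802.  An empty string trivially passes.
 */
bool
nonce_is_printable(const char *s)
{
    assert(s != NULL);

    for (; *s != '\0'; s++)
    {
        unsigned char c = (unsigned char) *s;

        if (c < 0x21 || c > 0x7E || c == ',')
            return false;
    }
    return true;
}
```

The comma is excluded because it separates attributes in SCRAM messages, so a nonce containing one could not be parsed unambiguously.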
    1119              : 
    1120              : /*
    1121              :  * Verify the final nonce contained in the last message received from
    1122              :  * client in an exchange.
    1123              :  */
    1124              : static bool
    1125           59 : verify_final_nonce(scram_state *state)
    1126              : {
    1127           59 :     int         client_nonce_len = strlen(state->client_nonce);
    1128           59 :     int         server_nonce_len = strlen(state->server_nonce);
    1129           59 :     int         final_nonce_len = strlen(state->client_final_nonce);
    1130              : 
    1131           59 :     if (final_nonce_len != client_nonce_len + server_nonce_len)
    1132            0 :         return false;
    1133           59 :     if (memcmp(state->client_final_nonce, state->client_nonce, client_nonce_len) != 0)
    1134            0 :         return false;
    1135           59 :     if (memcmp(state->client_final_nonce + client_nonce_len, state->server_nonce, server_nonce_len) != 0)
    1136            0 :         return false;
    1137              : 
    1138           59 :     return true;
    1139              : }
    1140              : 
    1141              : /*
    1142              :  * Verify the client proof contained in the last message received from
    1143              :  * client in an exchange.  Returns true if the verification is a success,
    1144              :  * or false for a failure.
    1145              :  */
    1146              : static bool
    1147           59 : verify_client_proof(scram_state *state)
    1148              : {
    1149              :     uint8       ClientSignature[SCRAM_MAX_KEY_LEN];
    1150              :     uint8       client_StoredKey[SCRAM_MAX_KEY_LEN];
    1151           59 :     pg_hmac_ctx *ctx = pg_hmac_create(state->hash_type);
    1152              :     int         i;
    1153           59 :     const char *errstr = NULL;
    1154              : 
    1155              :     /*
    1156              :      * Calculate ClientSignature.  Note that we don't directly log a failure
    1157              :      * here, even when processing the calculations, as this could involve a mock
    1158              :      * authentication.
    1159              :      */
    1160          118 :     if (pg_hmac_init(ctx, state->StoredKey, state->key_length) < 0 ||
    1161           59 :         pg_hmac_update(ctx,
    1162           59 :                        (uint8 *) state->client_first_message_bare,
    1163          118 :                        strlen(state->client_first_message_bare)) < 0 ||
    1164          118 :         pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
    1165           59 :         pg_hmac_update(ctx,
    1166           59 :                        (uint8 *) state->server_first_message,
    1167          118 :                        strlen(state->server_first_message)) < 0 ||
    1168          118 :         pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
    1169           59 :         pg_hmac_update(ctx,
    1170           59 :                        (uint8 *) state->client_final_message_without_proof,
    1171          118 :                        strlen(state->client_final_message_without_proof)) < 0 ||
    1172           59 :         pg_hmac_final(ctx, ClientSignature, state->key_length) < 0)
    1173              :     {
    1174            0 :         elog(ERROR, "could not calculate client signature: %s",
    1175              :              pg_hmac_error(ctx));
    1176              :     }
    1177              : 
    1178           59 :     pg_hmac_free(ctx);
    1179              : 
    1180              :     /* Extract the ClientKey that the client calculated from the proof */
    1181         1947 :     for (i = 0; i < state->key_length; i++)
    1182         1888 :         state->ClientKey[i] = state->ClientProof[i] ^ ClientSignature[i];
    1183              : 
    1184              :     /* Hash it one more time, and compare with StoredKey */
    1185           59 :     if (scram_H(state->ClientKey, state->hash_type, state->key_length,
    1186              :                 client_StoredKey, &errstr) < 0)
    1187            0 :         elog(ERROR, "could not hash stored key: %s", errstr);
    1188              : 
    1189           59 :     if (memcmp(client_StoredKey, state->StoredKey, state->key_length) != 0)
    1190            6 :         return false;
    1191              : 
    1192           53 :     return true;
    1193              : }
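
The XOR step in verify_client_proof() relies on the RFC 5802 definition ClientProof = ClientKey XOR ClientSignature, so XORing the proof with the signature a second time yields the ClientKey again. The sketch below shows only that byte-wise step; xor_bytes is a hypothetical helper, and the HMAC and hash calculations are deliberately left out.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/*
 * Hypothetical helper (not part of auth-scram.c): out[i] = a[i] XOR b[i]
 * for len bytes.  Because XOR is its own inverse, applying it to a proof
 * and the signature recovers the key the proof was built from.
 */
void
xor_bytes(uint8_t *out, const uint8_t *a, const uint8_t *b, size_t len)
{
    size_t      i;

    assert(out != NULL && a != NULL && b != NULL);

    for (i = 0; i < len; i++)
        out[i] = a[i] ^ b[i];
}
```

The server never stores the ClientKey itself. It only hashes the recovered value and compares the result with StoredKey, as the function above does.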
    1194              : 
    1195              : /*
    1196              :  * Build the first server-side message sent to the client in a SCRAM
    1197              :  * communication exchange.
    1198              :  */
    1199              : static char *
    1200           59 : build_server_first_message(scram_state *state)
    1201              : {
    1202              :     /*------
    1203              :      * The syntax for the server-first-message is: (RFC 5802)
    1204              :      *
    1205              :      * server-first-message =
    1206              :      *                   [reserved-mext ","] nonce "," salt ","
    1207              :      *                   iteration-count ["," extensions]
    1208              :      *
    1209              :      * nonce           = "r=" c-nonce [s-nonce]
    1210              :      *                   ;; Second part provided by server.
    1211              :      *
    1212              :      * c-nonce         = printable
    1213              :      *
    1214              :      * s-nonce         = printable
    1215              :      *
    1216              :      * salt            = "s=" base64
    1217              :      *
    1218              :      * iteration-count = "i=" posit-number
    1219              :      *                   ;; A positive number.
    1220              :      *
    1221              :      * Example:
    1222              :      *
    1223              :      * r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096
    1224              :      *------
    1225              :      */
    1226              : 
    1227              :     /*
    1228              :      * Per the spec, the nonce may consist of any printable ASCII characters.
    1229              :      * For convenience, however, we don't use the whole range available,
    1230              :      * rather, we generate some random bytes, and base64 encode them.
    1231              :      */
    1232              :     uint8       raw_nonce[SCRAM_RAW_NONCE_LEN];
    1233              :     int         encoded_len;
    1234              : 
    1235           59 :     if (!pg_strong_random(raw_nonce, SCRAM_RAW_NONCE_LEN))
    1236            0 :         ereport(ERROR,
    1237              :                 (errcode(ERRCODE_INTERNAL_ERROR),
    1238              :                  errmsg("could not generate random nonce")));
    1239              : 
    1240           59 :     encoded_len = pg_b64_enc_len(SCRAM_RAW_NONCE_LEN);
    1241              :     /* don't forget the zero-terminator */
    1242           59 :     state->server_nonce = palloc(encoded_len + 1);
    1243           59 :     encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN,
    1244              :                                 state->server_nonce, encoded_len);
    1245           59 :     if (encoded_len < 0)
    1246            0 :         ereport(ERROR,
    1247              :                 (errcode(ERRCODE_INTERNAL_ERROR),
    1248              :                  errmsg("could not encode random nonce")));
    1249           59 :     state->server_nonce[encoded_len] = '\0';
    1250              : 
    1251           59 :     state->server_first_message =
    1252           59 :         psprintf("r=%s%s,s=%s,i=%d",
    1253              :                  state->client_nonce, state->server_nonce,
    1254              :                  state->salt, state->iterations);
    1255              : 
    1256           59 :     return pstrdup(state->server_first_message);
    1257              : }
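
The layout of the server-first-message can be reproduced outside the backend with plain snprintf(). This is a sketch under stated assumptions: format_server_first is a hypothetical helper, it writes into a caller-supplied buffer instead of using psprintf(), and the values in the usage note are the example values from RFC 5802.

```c
#include <assert.h>
#include <stddef.h>
#include <stdio.h>
#include <string.h>

/*
 * Hypothetical helper (not part of auth-scram.c): writes
 * "r=<client nonce><server nonce>,s=<salt>,i=<iterations>" into buf.
 * Returns 0 on success, or -1 if the output did not fit or formatting
 * failed.
 */
int
format_server_first(char *buf, size_t buflen,
                    const char *client_nonce, const char *server_nonce,
                    const char *salt_b64, int iterations)
{
    int         n;

    assert(buf != NULL && buflen > 0);

    n = snprintf(buf, buflen, "r=%s%s,s=%s,i=%d",
                 client_nonce, server_nonce, salt_b64, iterations);
    return (n < 0 || (size_t) n >= buflen) ? -1 : 0;
}
```

Called with client nonce "fyko+d2lbbFgONRv9qkxdawL", server nonce "3rfcNHYJY1ZVvWVs7j", salt "QSXCR+Q6sek8bf92" and 4096 iterations, it produces the example string quoted in the comment above.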
    1258              : 
    1259              : 
    1260              : /*
    1261              :  * Read and parse the final message received from client.
    1262              :  */
    1263              : static void
    1264           59 : read_client_final_message(scram_state *state, const char *input)
    1265              : {
    1266              :     char        attr;
    1267              :     char       *channel_binding;
    1268              :     char       *value;
    1269              :     char       *begin,
    1270              :                *proof;
    1271              :     char       *p;
    1272              :     uint8      *client_proof;
    1273              :     int         client_proof_len;
    1274              : 
    1275           59 :     begin = p = pstrdup(input);
    1276              : 
    1277              :     /*------
    1278              :      * The syntax for the client-final-message is: (RFC 5802)
    1279              :      *
    1280              :      * gs2-header      = gs2-cbind-flag "," [ authzid ] ","
    1281              :      *                   ;; GS2 header for SCRAM
    1282              :      *                   ;; (the actual GS2 header includes an optional
    1283              :      *                   ;; flag to indicate that the GSS mechanism is not
    1284              :      *                   ;; "standard", but since SCRAM is "standard", we
    1285              :      *                   ;; don't include that flag).
    1286              :      *
    1287              :      * cbind-input   = gs2-header [ cbind-data ]
    1288              :      *                   ;; cbind-data MUST be present for
    1289              :      *                   ;; gs2-cbind-flag of "p" and MUST be absent
    1290              :      *                   ;; for "y" or "n".
    1291              :      *
    1292              :      * channel-binding = "c=" base64
    1293              :      *                   ;; base64 encoding of cbind-input.
    1294              :      *
    1295              :      * proof           = "p=" base64
    1296              :      *
    1297              :      * client-final-message-without-proof =
    1298              :      *                   channel-binding "," nonce [","
    1299              :      *                   extensions]
    1300              :      *
    1301              :      * client-final-message =
    1302              :      *                   client-final-message-without-proof "," proof
    1303              :      *------
    1304              :      */
    1305              : 
    1306              :     /*
    1307              :      * Read channel binding.  This repeats the channel-binding flags and is
    1308              :      * then followed by the actual binding data depending on the type.
    1309              :      */
    1310           59 :     channel_binding = read_attr_value(&p, 'c');
    1311           59 :     if (state->channel_binding_in_use)
    1312              :     {
    1313              : #ifdef USE_SSL
    1314            4 :         const char *cbind_data = NULL;
    1315            4 :         size_t      cbind_data_len = 0;
    1316              :         size_t      cbind_header_len;
    1317              :         char       *cbind_input;
    1318              :         size_t      cbind_input_len;
    1319              :         char       *b64_message;
    1320              :         int         b64_message_len;
    1321              : 
    1322              :         Assert(state->cbind_flag == 'p');
    1323              : 
    1324              :         /* Fetch hash data of server's SSL certificate */
    1325            4 :         cbind_data = be_tls_get_certificate_hash(state->port,
    1326              :                                                  &cbind_data_len);
    1327              : 
    1328              :         /* should not happen */
    1329            4 :         if (cbind_data == NULL || cbind_data_len == 0)
    1330            0 :             elog(ERROR, "could not get server certificate hash");
    1331              : 
    1332            4 :         cbind_header_len = strlen("p=tls-server-end-point,,");    /* p=type,, */
    1333            4 :         cbind_input_len = cbind_header_len + cbind_data_len;
    1334            4 :         cbind_input = palloc(cbind_input_len);
    1335            4 :         snprintf(cbind_input, cbind_input_len, "p=tls-server-end-point,,");
    1336            4 :         memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len);
    1337              : 
    1338            4 :         b64_message_len = pg_b64_enc_len(cbind_input_len);
    1339              :         /* don't forget the zero-terminator */
    1340            4 :         b64_message = palloc(b64_message_len + 1);
    1341            4 :         b64_message_len = pg_b64_encode((uint8 *) cbind_input, cbind_input_len,
    1342              :                                         b64_message, b64_message_len);
    1343            4 :         if (b64_message_len < 0)
    1344            0 :             elog(ERROR, "could not encode channel binding data");
    1345            4 :         b64_message[b64_message_len] = '\0';
    1346              : 
    1347              :         /*
    1348              :          * Compare the value sent by the client with the value expected by the
    1349              :          * server.
    1350              :          */
    1351            4 :         if (strcmp(channel_binding, b64_message) != 0)
    1352            0 :             ereport(ERROR,
    1353              :                     (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
    1354              :                      errmsg("SCRAM channel binding check failed")));
    1355              : #else
    1356              :         /* shouldn't happen, because we checked this earlier already */
    1357              :         elog(ERROR, "channel binding not supported by this build");
    1358              : #endif
    1359              :     }
    1360              :     else
    1361              :     {
    1362              :         /*
    1363              :          * If we are not using channel binding, the binding data is expected
    1364              :          * to always be "biws", which is "n,," base64-encoded, or "eSws",
    1365              :          * which is "y,,".  We also have to check whether the flag is the same
    1366              :          * one that the client originally sent.
    1367              :          */
    1368           55 :         if (!(strcmp(channel_binding, "biws") == 0 && state->cbind_flag == 'n') &&
    1369            0 :             !(strcmp(channel_binding, "eSws") == 0 && state->cbind_flag == 'y'))
    1370            0 :             ereport(ERROR,
    1371              :                     (errcode(ERRCODE_PROTOCOL_VIOLATION),
    1372              :                      errmsg("unexpected SCRAM channel-binding attribute in client-final-message")));
    1373              :     }
    1374              : 
    1375           59 :     state->client_final_nonce = read_attr_value(&p, 'r');
    1376              : 
    1377              :     /* ignore optional extensions, read until we find "p" attribute */
    1378              :     do
    1379              :     {
    1380           59 :         proof = p - 1;
    1381           59 :         value = read_any_attr(&p, &attr);
    1382           59 :     } while (attr != 'p');
    1383              : 
    1384           59 :     client_proof_len = pg_b64_dec_len(strlen(value));
    1385           59 :     client_proof = palloc(client_proof_len);
    1386           59 :     if (pg_b64_decode(value, strlen(value), client_proof,
    1387           59 :                       client_proof_len) != state->key_length)
    1388            0 :         ereport(ERROR,
    1389              :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
    1390              :                  errmsg("malformed SCRAM message"),
    1391              :                  errdetail("Malformed proof in client-final-message.")));
    1392           59 :     memcpy(state->ClientProof, client_proof, state->key_length);
    1393           59 :     pfree(client_proof);
    1394              : 
    1395           59 :     if (*p != '\0')
    1396            0 :         ereport(ERROR,
    1397              :                 (errcode(ERRCODE_PROTOCOL_VIOLATION),
    1398              :                  errmsg("malformed SCRAM message"),
    1399              :                  errdetail("Garbage found at the end of client-final-message.")));
    1400              : 
    1401           59 :     state->client_final_message_without_proof = palloc(proof - begin + 1);
    1402           59 :     memcpy(state->client_final_message_without_proof, input, proof - begin);
    1403           59 :     state->client_final_message_without_proof[proof - begin] = '\0';
    1404           59 : }
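
The constants "biws" and "eSws" checked in the non-channel-binding branch are the base64 encodings of the GS2 headers "n,," and "y,,". The sketch below derives them with a small standalone encoder. Both b64_encode_sketch and cbind_attr_matches_flag are hypothetical names used for illustration; the backend itself uses pg_b64_encode().

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <string.h>

/*
 * Hypothetical standalone base64 encoder (RFC 4648 alphabet, '=' padding).
 * dst must have room for 4 * ((len + 2) / 3) + 1 bytes.  Returns the
 * number of characters written, not counting the terminating NUL.
 */
size_t
b64_encode_sketch(const unsigned char *src, size_t len, char *dst)
{
    static const char tbl[] =
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    size_t      i;
    size_t      o = 0;

    assert(src != NULL && dst != NULL);

    /* full 3-byte groups become 4 output characters */
    for (i = 0; i + 2 < len; i += 3)
    {
        unsigned int v = ((unsigned int) src[i] << 16) |
                         ((unsigned int) src[i + 1] << 8) |
                         (unsigned int) src[i + 2];

        dst[o++] = tbl[(v >> 18) & 0x3F];
        dst[o++] = tbl[(v >> 12) & 0x3F];
        dst[o++] = tbl[(v >> 6) & 0x3F];
        dst[o++] = tbl[v & 0x3F];
    }

    /* a trailing group of 1 or 2 bytes is padded with '=' */
    if (i < len)
    {
        unsigned int v = (unsigned int) src[i] << 16;

        if (i + 1 < len)
            v |= (unsigned int) src[i + 1] << 8;
        dst[o++] = tbl[(v >> 18) & 0x3F];
        dst[o++] = tbl[(v >> 12) & 0x3F];
        dst[o++] = (i + 1 < len) ? tbl[(v >> 6) & 0x3F] : '=';
        dst[o++] = '=';
    }
    dst[o] = '\0';
    return o;
}

/*
 * Hypothetical check: does the "c=" attribute value equal the base64
 * encoding of "<flag>,," for the flag sent in the client-first-message?
 */
bool
cbind_attr_matches_flag(const char *attr, char flag)
{
    unsigned char hdr[3];
    char        expected[5];

    hdr[0] = (unsigned char) flag;
    hdr[1] = ',';
    hdr[2] = ',';
    b64_encode_sketch(hdr, sizeof(hdr), expected);
    return strcmp(attr, expected) == 0;
}
```

Comparing against the flag remembered from the first message, as the code above does, rejects a client that sends "n" in one message and the encoding of "y,," in the other.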
    1405              : 
    1406              : /*
    1407              :  * Build the final server-side message of an exchange.
    1408              :  */
    1409              : static char *
    1410           53 : build_server_final_message(scram_state *state)
    1411              : {
    1412              :     uint8       ServerSignature[SCRAM_MAX_KEY_LEN];
    1413              :     char       *server_signature_base64;
    1414              :     int         siglen;
    1415           53 :     pg_hmac_ctx *ctx = pg_hmac_create(state->hash_type);
    1416              : 
    1417              :     /* calculate ServerSignature */
    1418          106 :     if (pg_hmac_init(ctx, state->ServerKey, state->key_length) < 0 ||
    1419           53 :         pg_hmac_update(ctx,
    1420           53 :                        (uint8 *) state->client_first_message_bare,
    1421          106 :                        strlen(state->client_first_message_bare)) < 0 ||
    1422          106 :         pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
    1423           53 :         pg_hmac_update(ctx,
    1424           53 :                        (uint8 *) state->server_first_message,
    1425          106 :                        strlen(state->server_first_message)) < 0 ||
    1426          106 :         pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
    1427           53 :         pg_hmac_update(ctx,
    1428           53 :                        (uint8 *) state->client_final_message_without_proof,
    1429          106 :                        strlen(state->client_final_message_without_proof)) < 0 ||
    1430           53 :         pg_hmac_final(ctx, ServerSignature, state->key_length) < 0)
    1431              :     {
    1432            0 :         elog(ERROR, "could not calculate server signature: %s",
    1433              :              pg_hmac_error(ctx));
    1434              :     }
    1435              : 
    1436           53 :     pg_hmac_free(ctx);
    1437              : 
    1438           53 :     siglen = pg_b64_enc_len(state->key_length);
    1439              :     /* don't forget the zero-terminator */
    1440           53 :     server_signature_base64 = palloc(siglen + 1);
    1441           53 :     siglen = pg_b64_encode(ServerSignature,
    1442              :                            state->key_length, server_signature_base64,
    1443              :                            siglen);
    1444           53 :     if (siglen < 0)
    1445            0 :         elog(ERROR, "could not encode server signature");
    1446           53 :     server_signature_base64[siglen] = '\0';
    1447              : 
    1448              :     /*------
    1449              :      * The syntax for the server-final-message is: (RFC 5802)
    1450              :      *
    1451              :      * verifier        = "v=" base64
    1452              :      *                   ;; base-64 encoded ServerSignature.
    1453              :      *
    1454              :      * server-final-message = (server-error / verifier)
    1455              :      *                   ["," extensions]
    1456              :      *
    1457              :      *------
    1458              :      */
    1459           53 :     return psprintf("v=%s", server_signature_base64);
    1460              : }
    1461              : 
    1462              : 
    1463              : /*
    1464              :  * Deterministically generate salt for mock authentication, using a SHA256
    1465              :  * hash based on the username and a cluster-level secret key.  Returns a
    1466              :  * pointer to a static buffer of size SCRAM_DEFAULT_SALT_LEN, or NULL.
    1467              :  */
    1468              : static uint8 *
    1469            2 : scram_mock_salt(const char *username, pg_cryptohash_type hash_type,
    1470              :                 int key_length)
    1471              : {
    1472              :     pg_cryptohash_ctx *ctx;
    1473              :     static uint8 sha_digest[SCRAM_MAX_KEY_LEN];
    1474            2 :     char       *mock_auth_nonce = GetMockAuthenticationNonce();
    1475              : 
    1476              :     /*
    1477              :      * Generate salt using a SHA256 hash of the username and the cluster's
    1478              :      * mock authentication nonce.  (This works as long as the salt length is
    1479              :      * not larger than the SHA256 digest length.  If the salt is smaller, the
    1480              :      * caller will just ignore the extra data.)
    1481              :      */
    1482              :     StaticAssertDecl(PG_SHA256_DIGEST_LENGTH >= SCRAM_DEFAULT_SALT_LEN,
    1483              :                      "salt length greater than SHA256 digest length");
    1484              : 
    1485              :     /*
    1486              :      * This may be worth refreshing if support for more hash methods is
    1487              :      * added.
    1488              :      */
    1489              :     Assert(hash_type == PG_SHA256);
    1490              : 
    1491            2 :     ctx = pg_cryptohash_create(hash_type);
    1492            4 :     if (pg_cryptohash_init(ctx) < 0 ||
    1493            4 :         pg_cryptohash_update(ctx, (const uint8 *) username, strlen(username)) < 0 ||
    1494            4 :         pg_cryptohash_update(ctx, (const uint8 *) mock_auth_nonce, MOCK_AUTH_NONCE_LEN) < 0 ||
    1495            2 :         pg_cryptohash_final(ctx, sha_digest, key_length) < 0)
    1496              :     {
    1497            0 :         pg_cryptohash_free(ctx);
    1498            0 :         return NULL;
    1499              :     }
    1500            2 :     pg_cryptohash_free(ctx);
    1501              : 
    1502            2 :     return sha_digest;
    1503              : }
        

Generated by: LCOV version 2.0-1