LCOV - code coverage report
Current view: top level - src/interfaces/libpq - fe-auth-scram.c (source / functions) Hit Total Coverage
Test: PostgreSQL 18devel Lines: 256 382 67.0 %
Date: 2025-02-21 18:14:53 Functions: 12 12 100.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : /*-------------------------------------------------------------------------
       2             :  *
       3             :  * fe-auth-scram.c
       4             :  *     The front-end (client) implementation of SCRAM authentication.
       5             :  *
       6             :  * Portions Copyright (c) 1996-2025, PostgreSQL Global Development Group
       7             :  * Portions Copyright (c) 1994, Regents of the University of California
       8             :  *
       9             :  * IDENTIFICATION
      10             :  *    src/interfaces/libpq/fe-auth-scram.c
      11             :  *
      12             :  *-------------------------------------------------------------------------
      13             :  */
      14             : 
      15             : #include "postgres_fe.h"
      16             : 
      17             : #include "common/base64.h"
      18             : #include "common/hmac.h"
      19             : #include "common/saslprep.h"
      20             : #include "common/scram-common.h"
      21             : #include "fe-auth.h"
      22             : 
      23             : 
      24             : /* The exported SCRAM callback mechanism. */
      25             : static void *scram_init(PGconn *conn, const char *password,
      26             :                         const char *sasl_mechanism);
      27             : static SASLStatus scram_exchange(void *opaq, bool final,
      28             :                                  char *input, int inputlen,
      29             :                                  char **output, int *outputlen);
      30             : static bool scram_channel_bound(void *opaq);
      31             : static void scram_free(void *opaq);
      32             : 
      33             : const pg_fe_sasl_mech pg_scram_mech = {
      34             :     scram_init,
      35             :     scram_exchange,
      36             :     scram_channel_bound,
      37             :     scram_free
      38             : };
      39             : 
      40             : /*
      41             :  * Status of exchange messages used for SCRAM authentication via the
      42             :  * SASL protocol.
      43             :  */
      44             : typedef enum
      45             : {
      46             :     FE_SCRAM_INIT,
      47             :     FE_SCRAM_NONCE_SENT,
      48             :     FE_SCRAM_PROOF_SENT,
      49             :     FE_SCRAM_FINISHED,
      50             : } fe_scram_state_enum;
      51             : 
      52             : typedef struct
      53             : {
      54             :     fe_scram_state_enum state;
      55             : 
      56             :     /* These are supplied by the user */
      57             :     PGconn     *conn;
      58             :     char       *password;
      59             :     char       *sasl_mechanism;
      60             : 
      61             :     /* State data depending on the hash type */
      62             :     pg_cryptohash_type hash_type;
      63             :     int         key_length;
      64             : 
      65             :     /* We construct these */
      66             :     uint8       SaltedPassword[SCRAM_MAX_KEY_LEN];
      67             :     char       *client_nonce;
      68             :     char       *client_first_message_bare;
      69             :     char       *client_final_message_without_proof;
      70             : 
      71             :     /* These come from the server-first message */
      72             :     char       *server_first_message;
      73             :     char       *salt;
      74             :     int         saltlen;
      75             :     int         iterations;
      76             :     char       *nonce;
      77             : 
      78             :     /* These come from the server-final message */
      79             :     char       *server_final_message;
      80             :     char        ServerSignature[SCRAM_MAX_KEY_LEN];
      81             : } fe_scram_state;
      82             : 
      83             : static bool read_server_first_message(fe_scram_state *state, char *input);
      84             : static bool read_server_final_message(fe_scram_state *state, char *input);
      85             : static char *build_client_first_message(fe_scram_state *state);
      86             : static char *build_client_final_message(fe_scram_state *state);
      87             : static bool verify_server_signature(fe_scram_state *state, bool *match,
      88             :                                     const char **errstr);
      89             : static bool calculate_client_proof(fe_scram_state *state,
      90             :                                    const char *client_final_message_without_proof,
      91             :                                    uint8 *result, const char **errstr);
      92             : 
      93             : /*
      94             :  * Initialize SCRAM exchange status.
      95             :  */
      96             : static void *
      97          98 : scram_init(PGconn *conn,
      98             :            const char *password,
      99             :            const char *sasl_mechanism)
     100             : {
     101             :     fe_scram_state *state;
     102             :     char       *prep_password;
     103             :     pg_saslprep_rc rc;
     104             : 
     105             :     Assert(sasl_mechanism != NULL);
     106             : 
     107          98 :     state = (fe_scram_state *) malloc(sizeof(fe_scram_state));
     108          98 :     if (!state)
     109           0 :         return NULL;
     110          98 :     memset(state, 0, sizeof(fe_scram_state));
     111          98 :     state->conn = conn;
     112          98 :     state->state = FE_SCRAM_INIT;
     113          98 :     state->key_length = SCRAM_SHA_256_KEY_LEN;
     114          98 :     state->hash_type = PG_SHA256;
     115             : 
     116          98 :     state->sasl_mechanism = strdup(sasl_mechanism);
     117          98 :     if (!state->sasl_mechanism)
     118             :     {
     119           0 :         free(state);
     120           0 :         return NULL;
     121             :     }
     122             : 
     123          98 :     if (password)
     124             :     {
     125             :         /* Normalize the password with SASLprep, if possible */
     126          90 :         rc = pg_saslprep(password, &prep_password);
     127          90 :         if (rc == SASLPREP_OOM)
     128             :         {
     129           0 :             free(state->sasl_mechanism);
     130           0 :             free(state);
     131           0 :             return NULL;
     132             :         }
     133          90 :         if (rc != SASLPREP_SUCCESS)
     134             :         {
     135           4 :             prep_password = strdup(password);
     136           4 :             if (!prep_password)
     137             :             {
     138           0 :                 free(state->sasl_mechanism);
     139           0 :                 free(state);
     140           0 :                 return NULL;
     141             :             }
     142             :         }
     143          90 :         state->password = prep_password;
     144             :     }
     145             : 
     146          98 :     return state;
     147             : }
     148             : 
     149             : /*
     150             :  * Return true if channel binding was employed and the SCRAM exchange
     151             :  * completed. This should be used after a successful exchange to determine
     152             :  * whether the server authenticated itself to the client.
     153             :  *
     154             :  * Note that the caller must also ensure that the exchange was actually
     155             :  * successful.
     156             :  */
     157             : static bool
     158           6 : scram_channel_bound(void *opaq)
     159             : {
     160           6 :     fe_scram_state *state = (fe_scram_state *) opaq;
     161             : 
     162             :     /* no SCRAM exchange done */
     163           6 :     if (state == NULL)
     164           0 :         return false;
     165             : 
     166             :     /* SCRAM exchange not completed */
     167           6 :     if (state->state != FE_SCRAM_FINISHED)
     168           0 :         return false;
     169             : 
     170             :     /* channel binding mechanism not used */
     171           6 :     if (strcmp(state->sasl_mechanism, SCRAM_SHA_256_PLUS_NAME) != 0)
     172           0 :         return false;
     173             : 
     174             :     /* all clear! */
     175           6 :     return true;
     176             : }
     177             : 
     178             : /*
     179             :  * Free SCRAM exchange status
     180             :  */
     181             : static void
     182          90 : scram_free(void *opaq)
     183             : {
     184          90 :     fe_scram_state *state = (fe_scram_state *) opaq;
     185             : 
     186          90 :     free(state->password);
     187          90 :     free(state->sasl_mechanism);
     188             : 
     189             :     /* client messages */
     190          90 :     free(state->client_nonce);
     191          90 :     free(state->client_first_message_bare);
     192          90 :     free(state->client_final_message_without_proof);
     193             : 
     194             :     /* first message from server */
     195          90 :     free(state->server_first_message);
     196          90 :     free(state->salt);
     197          90 :     free(state->nonce);
     198             : 
     199             :     /* final message from server */
     200          90 :     free(state->server_final_message);
     201             : 
     202          90 :     free(state);
     203          90 : }
     204             : 
     205             : /*
     206             :  * Exchange a SCRAM message with backend.
     207             :  */
     208             : static SASLStatus
     209         282 : scram_exchange(void *opaq, bool final,
     210             :                char *input, int inputlen,
     211             :                char **output, int *outputlen)
     212             : {
     213         282 :     fe_scram_state *state = (fe_scram_state *) opaq;
     214         282 :     PGconn     *conn = state->conn;
     215         282 :     const char *errstr = NULL;
     216             : 
     217         282 :     *output = NULL;
     218         282 :     *outputlen = 0;
     219             : 
     220             :     /*
     221             :      * Check that the input length agrees with the string length of the input.
     222             :      * We can ignore inputlen after this.
     223             :      */
     224         282 :     if (state->state != FE_SCRAM_INIT)
     225             :     {
     226         184 :         if (inputlen == 0)
     227             :         {
     228           0 :             libpq_append_conn_error(conn, "malformed SCRAM message (empty message)");
     229           0 :             return SASL_FAILED;
     230             :         }
     231         184 :         if (inputlen != strlen(input))
     232             :         {
     233           0 :             libpq_append_conn_error(conn, "malformed SCRAM message (length mismatch)");
     234           0 :             return SASL_FAILED;
     235             :         }
     236             :     }
     237             : 
     238         282 :     switch (state->state)
     239             :     {
     240          98 :         case FE_SCRAM_INIT:
     241             :             /* Begin the SCRAM handshake, by sending client nonce */
     242          98 :             *output = build_client_first_message(state);
     243          98 :             if (*output == NULL)
     244           0 :                 return SASL_FAILED;
     245             : 
     246          98 :             *outputlen = strlen(*output);
     247          98 :             state->state = FE_SCRAM_NONCE_SENT;
     248          98 :             return SASL_CONTINUE;
     249             : 
     250          98 :         case FE_SCRAM_NONCE_SENT:
     251             :             /* Receive salt and server nonce, send response. */
     252          98 :             if (!read_server_first_message(state, input))
     253           0 :                 return SASL_FAILED;
     254             : 
     255          98 :             *output = build_client_final_message(state);
     256          98 :             if (*output == NULL)
     257           0 :                 return SASL_FAILED;
     258             : 
     259          98 :             *outputlen = strlen(*output);
     260          98 :             state->state = FE_SCRAM_PROOF_SENT;
     261          98 :             return SASL_CONTINUE;
     262             : 
     263          86 :         case FE_SCRAM_PROOF_SENT:
     264             :             {
     265             :                 bool        match;
     266             : 
     267             :                 /* Receive server signature */
     268          86 :                 if (!read_server_final_message(state, input))
     269           0 :                     return SASL_FAILED;
     270             : 
     271             :                 /*
     272             :                  * Verify server signature, to make sure we're talking to the
     273             :                  * genuine server.
     274             :                  */
     275          86 :                 if (!verify_server_signature(state, &match, &errstr))
     276             :                 {
     277           0 :                     libpq_append_conn_error(conn, "could not verify server signature: %s", errstr);
     278           0 :                     return SASL_FAILED;
     279             :                 }
     280             : 
     281          86 :                 if (!match)
     282             :                 {
     283           0 :                     libpq_append_conn_error(conn, "incorrect server signature");
     284             :                 }
     285          86 :                 state->state = FE_SCRAM_FINISHED;
     286          86 :                 state->conn->client_finished_auth = true;
     287          86 :                 return match ? SASL_COMPLETE : SASL_FAILED;
     288             :             }
     289             : 
     290           0 :         default:
     291             :             /* shouldn't happen */
     292           0 :             libpq_append_conn_error(conn, "invalid SCRAM exchange state");
     293           0 :             break;
     294             :     }
     295             : 
     296           0 :     return SASL_FAILED;
     297             : }
     298             : 
     299             : /*
     300             :  * Read value for an attribute part of a SCRAM message.
     301             :  *
     302             :  * The buffer at **input is destructively modified, and *input is
     303             :  * advanced over the "attr=value" string and any following comma.
     304             :  *
     305             :  * On failure, append an error message to *errorMessage and return NULL.
     306             :  */
     307             : static char *
     308         380 : read_attr_value(char **input, char attr, PQExpBuffer errorMessage)
     309             : {
     310         380 :     char       *begin = *input;
     311             :     char       *end;
     312             : 
     313         380 :     if (*begin != attr)
     314             :     {
     315           0 :         libpq_append_error(errorMessage,
     316             :                            "malformed SCRAM message (attribute \"%c\" expected)",
     317             :                            attr);
     318           0 :         return NULL;
     319             :     }
     320         380 :     begin++;
     321             : 
     322         380 :     if (*begin != '=')
     323             :     {
     324           0 :         libpq_append_error(errorMessage,
     325             :                            "malformed SCRAM message (expected character \"=\" for attribute \"%c\")",
     326             :                            attr);
     327           0 :         return NULL;
     328             :     }
     329         380 :     begin++;
     330             : 
     331         380 :     end = begin;
     332       11608 :     while (*end && *end != ',')
     333       11228 :         end++;
     334             : 
     335         380 :     if (*end)
     336             :     {
     337         196 :         *end = '\0';
     338         196 :         *input = end + 1;
     339             :     }
     340             :     else
     341         184 :         *input = end;
     342             : 
     343         380 :     return begin;
     344             : }
     345             : 
     346             : /*
     347             :  * Build the first exchange message sent by the client.
     348             :  */
     349             : static char *
     350          98 : build_client_first_message(fe_scram_state *state)
     351             : {
     352          98 :     PGconn     *conn = state->conn;
     353             :     char        raw_nonce[SCRAM_RAW_NONCE_LEN + 1];
     354             :     char       *result;
     355             :     int         channel_info_len;
     356             :     int         encoded_len;
     357             :     PQExpBufferData buf;
     358             : 
     359             :     /*
     360             :      * Generate a "raw" nonce.  This is converted to ASCII-printable form by
     361             :      * base64-encoding it.
     362             :      */
     363          98 :     if (!pg_strong_random(raw_nonce, SCRAM_RAW_NONCE_LEN))
     364             :     {
     365           0 :         libpq_append_conn_error(conn, "could not generate nonce");
     366           0 :         return NULL;
     367             :     }
     368             : 
     369          98 :     encoded_len = pg_b64_enc_len(SCRAM_RAW_NONCE_LEN);
     370             :     /* don't forget the zero-terminator */
     371          98 :     state->client_nonce = malloc(encoded_len + 1);
     372          98 :     if (state->client_nonce == NULL)
     373             :     {
     374           0 :         libpq_append_conn_error(conn, "out of memory");
     375           0 :         return NULL;
     376             :     }
     377          98 :     encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN,
     378             :                                 state->client_nonce, encoded_len);
     379          98 :     if (encoded_len < 0)
     380             :     {
     381           0 :         libpq_append_conn_error(conn, "could not encode nonce");
     382           0 :         return NULL;
     383             :     }
     384          98 :     state->client_nonce[encoded_len] = '\0';
     385             : 
     386             :     /*
     387             :      * Generate message.  The username is left empty as the backend uses the
     388             :      * value provided by the startup packet.  Also, as this username is not
     389             :      * prepared with SASLprep, the message parsing would fail if it includes
     390             :      * '=' or ',' characters.
     391             :      */
     392             : 
     393          98 :     initPQExpBuffer(&buf);
     394             : 
     395             :     /*
     396             :      * First build the gs2-header with channel binding information.
     397             :      */
     398          98 :     if (strcmp(state->sasl_mechanism, SCRAM_SHA_256_PLUS_NAME) == 0)
     399             :     {
     400             :         Assert(conn->ssl_in_use);
     401          10 :         appendPQExpBufferStr(&buf, "p=tls-server-end-point");
     402             :     }
     403             : #ifdef USE_SSL
     404          88 :     else if (conn->channel_binding[0] != 'd' && /* disable */
     405          84 :              conn->ssl_in_use)
     406             :     {
     407             :         /*
     408             :          * Client supports channel binding, but thinks the server does not.
     409             :          */
     410           0 :         appendPQExpBufferChar(&buf, 'y');
     411             :     }
     412             : #endif
     413             :     else
     414             :     {
     415             :         /*
     416             :          * Client does not support channel binding, or has disabled it.
     417             :          */
     418          88 :         appendPQExpBufferChar(&buf, 'n');
     419             :     }
     420             : 
     421          98 :     if (PQExpBufferDataBroken(buf))
     422           0 :         goto oom_error;
     423             : 
     424          98 :     channel_info_len = buf.len;
     425             : 
     426          98 :     appendPQExpBuffer(&buf, ",,n=,r=%s", state->client_nonce);
     427          98 :     if (PQExpBufferDataBroken(buf))
     428           0 :         goto oom_error;
     429             : 
     430             :     /*
     431             :      * The first message content needs to be saved without channel binding
     432             :      * information.
     433             :      */
     434          98 :     state->client_first_message_bare = strdup(buf.data + channel_info_len + 2);
     435          98 :     if (!state->client_first_message_bare)
     436           0 :         goto oom_error;
     437             : 
     438          98 :     result = strdup(buf.data);
     439          98 :     if (result == NULL)
     440           0 :         goto oom_error;
     441             : 
     442          98 :     termPQExpBuffer(&buf);
     443          98 :     return result;
     444             : 
     445           0 : oom_error:
     446           0 :     termPQExpBuffer(&buf);
     447           0 :     libpq_append_conn_error(conn, "out of memory");
     448           0 :     return NULL;
     449             : }
     450             : 
     451             : /*
     452             :  * Build the final exchange message sent from the client.
     453             :  */
     454             : static char *
     455          98 : build_client_final_message(fe_scram_state *state)
     456             : {
     457             :     PQExpBufferData buf;
     458          98 :     PGconn     *conn = state->conn;
     459             :     uint8       client_proof[SCRAM_MAX_KEY_LEN];
     460             :     char       *result;
     461             :     int         encoded_len;
     462          98 :     const char *errstr = NULL;
     463             : 
     464          98 :     initPQExpBuffer(&buf);
     465             : 
     466             :     /*
     467             :      * Construct client-final-message-without-proof.  We need to remember it
     468             :      * for verifying the server proof in the final step of authentication.
     469             :      *
     470             :      * The channel binding flag handling (p/y/n) must be consistent with
     471             :      * build_client_first_message(), because the server will check that it's
     472             :      * the same flag both times.
     473             :      */
     474          98 :     if (strcmp(state->sasl_mechanism, SCRAM_SHA_256_PLUS_NAME) == 0)
     475             :     {
     476             : #ifdef USE_SSL
     477          10 :         char       *cbind_data = NULL;
     478          10 :         size_t      cbind_data_len = 0;
     479             :         size_t      cbind_header_len;
     480             :         char       *cbind_input;
     481             :         size_t      cbind_input_len;
     482             :         int         encoded_cbind_len;
     483             : 
     484             :         /* Fetch hash data of server's SSL certificate */
     485             :         cbind_data =
     486          10 :             pgtls_get_peer_certificate_hash(state->conn,
     487             :                                             &cbind_data_len);
     488          10 :         if (cbind_data == NULL)
     489             :         {
     490             :             /* error message is already set on error */
     491           0 :             termPQExpBuffer(&buf);
     492           0 :             return NULL;
     493             :         }
     494             : 
     495          10 :         appendPQExpBufferStr(&buf, "c=");
     496             : 
     497             :         /* p=type,, */
     498          10 :         cbind_header_len = strlen("p=tls-server-end-point,,");
     499          10 :         cbind_input_len = cbind_header_len + cbind_data_len;
     500          10 :         cbind_input = malloc(cbind_input_len);
     501          10 :         if (!cbind_input)
     502             :         {
     503           0 :             free(cbind_data);
     504           0 :             goto oom_error;
     505             :         }
     506          10 :         memcpy(cbind_input, "p=tls-server-end-point,,", cbind_header_len);
     507          10 :         memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len);
     508             : 
     509          10 :         encoded_cbind_len = pg_b64_enc_len(cbind_input_len);
     510          10 :         if (!enlargePQExpBuffer(&buf, encoded_cbind_len))
     511             :         {
     512           0 :             free(cbind_data);
     513           0 :             free(cbind_input);
     514           0 :             goto oom_error;
     515             :         }
     516          10 :         encoded_cbind_len = pg_b64_encode(cbind_input, cbind_input_len,
     517          10 :                                           buf.data + buf.len,
     518             :                                           encoded_cbind_len);
     519          10 :         if (encoded_cbind_len < 0)
     520             :         {
     521           0 :             free(cbind_data);
     522           0 :             free(cbind_input);
     523           0 :             termPQExpBuffer(&buf);
     524           0 :             appendPQExpBufferStr(&conn->errorMessage,
     525             :                                  "could not encode cbind data for channel binding\n");
     526           0 :             return NULL;
     527             :         }
     528          10 :         buf.len += encoded_cbind_len;
     529          10 :         buf.data[buf.len] = '\0';
     530             : 
     531          10 :         free(cbind_data);
     532          10 :         free(cbind_input);
     533             : #else
     534             :         /*
     535             :          * Chose channel binding, but the SSL library doesn't support it.
     536             :          * Shouldn't happen.
     537             :          */
     538             :         termPQExpBuffer(&buf);
     539             :         appendPQExpBufferStr(&conn->errorMessage,
     540             :                              "channel binding not supported by this build\n");
     541             :         return NULL;
     542             : #endif                          /* USE_SSL */
     543             :     }
     544             : #ifdef USE_SSL
     545          88 :     else if (conn->channel_binding[0] != 'd' && /* disable */
     546          84 :              conn->ssl_in_use)
     547           0 :         appendPQExpBufferStr(&buf, "c=eSws"); /* base64 of "y,," */
     548             : #endif
     549             :     else
     550          88 :         appendPQExpBufferStr(&buf, "c=biws"); /* base64 of "n,," */
     551             : 
     552          98 :     if (PQExpBufferDataBroken(buf))
     553           0 :         goto oom_error;
     554             : 
     555          98 :     appendPQExpBuffer(&buf, ",r=%s", state->nonce);
     556          98 :     if (PQExpBufferDataBroken(buf))
     557           0 :         goto oom_error;
     558             : 
     559          98 :     state->client_final_message_without_proof = strdup(buf.data);
     560          98 :     if (state->client_final_message_without_proof == NULL)
     561           0 :         goto oom_error;
     562             : 
     563             :     /* Append proof to it, to form client-final-message. */
     564          98 :     if (!calculate_client_proof(state,
     565          98 :                                 state->client_final_message_without_proof,
     566             :                                 client_proof, &errstr))
     567             :     {
     568           0 :         termPQExpBuffer(&buf);
     569           0 :         libpq_append_conn_error(conn, "could not calculate client proof: %s", errstr);
     570           0 :         return NULL;
     571             :     }
     572             : 
     573          98 :     appendPQExpBufferStr(&buf, ",p=");
     574          98 :     encoded_len = pg_b64_enc_len(state->key_length);
     575          98 :     if (!enlargePQExpBuffer(&buf, encoded_len))
     576           0 :         goto oom_error;
     577          98 :     encoded_len = pg_b64_encode((char *) client_proof,
     578             :                                 state->key_length,
     579          98 :                                 buf.data + buf.len,
     580             :                                 encoded_len);
     581          98 :     if (encoded_len < 0)
     582             :     {
     583           0 :         termPQExpBuffer(&buf);
     584           0 :         libpq_append_conn_error(conn, "could not encode client proof");
     585           0 :         return NULL;
     586             :     }
     587          98 :     buf.len += encoded_len;
     588          98 :     buf.data[buf.len] = '\0';
     589             : 
     590          98 :     result = strdup(buf.data);
     591          98 :     if (result == NULL)
     592           0 :         goto oom_error;
     593             : 
     594          98 :     termPQExpBuffer(&buf);
     595          98 :     return result;
     596             : 
     597           0 : oom_error:
     598           0 :     termPQExpBuffer(&buf);
     599           0 :     libpq_append_conn_error(conn, "out of memory");
     600           0 :     return NULL;
     601             : }
     602             : 
     603             : /*
     604             :  * Read the first exchange message coming from the server.
     605             :  */
     606             : static bool
     607          98 : read_server_first_message(fe_scram_state *state, char *input)
     608             : {
     609          98 :     PGconn     *conn = state->conn;
     610             :     char       *iterations_str;
     611             :     char       *endptr;
     612             :     char       *encoded_salt;
     613             :     char       *nonce;
     614             :     int         decoded_salt_len;
     615             : 
     616          98 :     state->server_first_message = strdup(input);
     617          98 :     if (state->server_first_message == NULL)
     618             :     {
     619           0 :         libpq_append_conn_error(conn, "out of memory");
     620           0 :         return false;
     621             :     }
     622             : 
     623             :     /* parse the message */
     624          98 :     nonce = read_attr_value(&input, 'r',
     625             :                             &conn->errorMessage);
     626          98 :     if (nonce == NULL)
     627             :     {
     628             :         /* read_attr_value() has appended an error string */
     629           0 :         return false;
     630             :     }
     631             : 
     632             :     /* Verify immediately that the server used our part of the nonce */
     633          98 :     if (strlen(nonce) < strlen(state->client_nonce) ||
     634          98 :         memcmp(nonce, state->client_nonce, strlen(state->client_nonce)) != 0)
     635             :     {
     636           0 :         libpq_append_conn_error(conn, "invalid SCRAM response (nonce mismatch)");
     637           0 :         return false;
     638             :     }
     639             : 
     640          98 :     state->nonce = strdup(nonce);
     641          98 :     if (state->nonce == NULL)
     642             :     {
     643           0 :         libpq_append_conn_error(conn, "out of memory");
     644           0 :         return false;
     645             :     }
     646             : 
     647          98 :     encoded_salt = read_attr_value(&input, 's', &conn->errorMessage);
     648          98 :     if (encoded_salt == NULL)
     649             :     {
     650             :         /* read_attr_value() has appended an error string */
     651           0 :         return false;
     652             :     }
     653          98 :     decoded_salt_len = pg_b64_dec_len(strlen(encoded_salt));
     654          98 :     state->salt = malloc(decoded_salt_len);
     655          98 :     if (state->salt == NULL)
     656             :     {
     657           0 :         libpq_append_conn_error(conn, "out of memory");
     658           0 :         return false;
     659             :     }
     660         196 :     state->saltlen = pg_b64_decode(encoded_salt,
     661          98 :                                    strlen(encoded_salt),
     662             :                                    state->salt,
     663             :                                    decoded_salt_len);
     664          98 :     if (state->saltlen < 0)
     665             :     {
     666           0 :         libpq_append_conn_error(conn, "malformed SCRAM message (invalid salt)");
     667           0 :         return false;
     668             :     }
     669             : 
     670          98 :     iterations_str = read_attr_value(&input, 'i', &conn->errorMessage);
     671          98 :     if (iterations_str == NULL)
     672             :     {
     673             :         /* read_attr_value() has appended an error string */
     674           0 :         return false;
     675             :     }
     676          98 :     state->iterations = strtol(iterations_str, &endptr, 10);
     677          98 :     if (*endptr != '\0' || state->iterations < 1)
     678             :     {
     679           0 :         libpq_append_conn_error(conn, "malformed SCRAM message (invalid iteration count)");
     680           0 :         return false;
     681             :     }
     682             : 
     683          98 :     if (*input != '\0')
     684           0 :         libpq_append_conn_error(conn, "malformed SCRAM message (garbage at end of server-first-message)");
     685             : 
     686          98 :     return true;
     687             : }
     688             : 
     689             : /*
     690             :  * Read the final exchange message coming from the server.
     691             :  */
     692             : static bool
     693          86 : read_server_final_message(fe_scram_state *state, char *input)
     694             : {
     695          86 :     PGconn     *conn = state->conn;
     696             :     char       *encoded_server_signature;
     697             :     char       *decoded_server_signature;
     698             :     int         server_signature_len;
     699             : 
     700          86 :     state->server_final_message = strdup(input);
     701          86 :     if (!state->server_final_message)
     702             :     {
     703           0 :         libpq_append_conn_error(conn, "out of memory");
     704           0 :         return false;
     705             :     }
     706             : 
     707             :     /* Check for error result. */
     708          86 :     if (*input == 'e')
     709             :     {
     710           0 :         char       *errmsg = read_attr_value(&input, 'e',
     711             :                                              &conn->errorMessage);
     712             : 
     713           0 :         if (errmsg == NULL)
     714             :         {
     715             :             /* read_attr_value() has appended an error message */
     716           0 :             return false;
     717             :         }
     718           0 :         libpq_append_conn_error(conn, "error received from server in SCRAM exchange: %s",
     719             :                                 errmsg);
     720           0 :         return false;
     721             :     }
     722             : 
     723             :     /* Parse the message. */
     724          86 :     encoded_server_signature = read_attr_value(&input, 'v',
     725             :                                                &conn->errorMessage);
     726          86 :     if (encoded_server_signature == NULL)
     727             :     {
     728             :         /* read_attr_value() has appended an error message */
     729           0 :         return false;
     730             :     }
     731             : 
     732          86 :     if (*input != '\0')
     733           0 :         libpq_append_conn_error(conn, "malformed SCRAM message (garbage at end of server-final-message)");
     734             : 
     735          86 :     server_signature_len = pg_b64_dec_len(strlen(encoded_server_signature));
     736          86 :     decoded_server_signature = malloc(server_signature_len);
     737          86 :     if (!decoded_server_signature)
     738             :     {
     739           0 :         libpq_append_conn_error(conn, "out of memory");
     740           0 :         return false;
     741             :     }
     742             : 
     743          86 :     server_signature_len = pg_b64_decode(encoded_server_signature,
     744          86 :                                          strlen(encoded_server_signature),
     745             :                                          decoded_server_signature,
     746             :                                          server_signature_len);
     747          86 :     if (server_signature_len != state->key_length)
     748             :     {
     749           0 :         free(decoded_server_signature);
     750           0 :         libpq_append_conn_error(conn, "malformed SCRAM message (invalid server signature)");
     751           0 :         return false;
     752             :     }
     753          86 :     memcpy(state->ServerSignature, decoded_server_signature,
     754          86 :            state->key_length);
     755          86 :     free(decoded_server_signature);
     756             : 
     757          86 :     return true;
     758             : }
     759             : 
     760             : /*
     761             :  * Calculate the client proof, part of the final exchange message sent
     762             :  * by the client.  Returns true on success, false on failure with *errstr
     763             :  * pointing to a message about the error details.
     764             :  */
     765             : static bool
     766          98 : calculate_client_proof(fe_scram_state *state,
     767             :                        const char *client_final_message_without_proof,
     768             :                        uint8 *result, const char **errstr)
     769             : {
     770             :     uint8       StoredKey[SCRAM_MAX_KEY_LEN];
     771             :     uint8       ClientKey[SCRAM_MAX_KEY_LEN];
     772             :     uint8       ClientSignature[SCRAM_MAX_KEY_LEN];
     773             :     int         i;
     774             :     pg_hmac_ctx *ctx;
     775             : 
     776          98 :     ctx = pg_hmac_create(state->hash_type);
     777          98 :     if (ctx == NULL)
     778             :     {
     779           0 :         *errstr = pg_hmac_error(NULL);  /* returns OOM */
     780           0 :         return false;
     781             :     }
     782             : 
     783          98 :     if (state->conn->scram_client_key_binary)
     784             :     {
     785           8 :         memcpy(ClientKey, state->conn->scram_client_key_binary, SCRAM_MAX_KEY_LEN);
     786             :     }
     787             :     else
     788             :     {
     789             :         /*
     790             :          * Calculate SaltedPassword, and store it in 'state' so that we can
     791             :          * reuse it later in verify_server_signature.
     792             :          */
     793          90 :         if (scram_SaltedPassword(state->password, state->hash_type,
     794          90 :                                  state->key_length, state->salt, state->saltlen,
     795          90 :                                  state->iterations, state->SaltedPassword,
     796          90 :                                  errstr) < 0 ||
     797          90 :             scram_ClientKey(state->SaltedPassword, state->hash_type,
     798             :                             state->key_length, ClientKey, errstr) < 0)
     799             :         {
     800             :             /* errstr is already filled here */
     801           0 :             pg_hmac_free(ctx);
     802           0 :             return false;
     803             :         }
     804             :     }
     805             : 
     806          98 :     if (scram_H(ClientKey, state->hash_type, state->key_length, StoredKey, errstr) < 0)
     807             :     {
     808           0 :         pg_hmac_free(ctx);
     809           0 :         return false;
     810             :     }
     811             : 
     812         196 :     if (pg_hmac_init(ctx, StoredKey, state->key_length) < 0 ||
     813          98 :         pg_hmac_update(ctx,
     814          98 :                        (uint8 *) state->client_first_message_bare,
     815         196 :                        strlen(state->client_first_message_bare)) < 0 ||
     816         196 :         pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
     817          98 :         pg_hmac_update(ctx,
     818          98 :                        (uint8 *) state->server_first_message,
     819         196 :                        strlen(state->server_first_message)) < 0 ||
     820         196 :         pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
     821          98 :         pg_hmac_update(ctx,
     822             :                        (uint8 *) client_final_message_without_proof,
     823          98 :                        strlen(client_final_message_without_proof)) < 0 ||
     824          98 :         pg_hmac_final(ctx, ClientSignature, state->key_length) < 0)
     825             :     {
     826           0 :         *errstr = pg_hmac_error(ctx);
     827           0 :         pg_hmac_free(ctx);
     828           0 :         return false;
     829             :     }
     830             : 
     831        3234 :     for (i = 0; i < state->key_length; i++)
     832        3136 :         result[i] = ClientKey[i] ^ ClientSignature[i];
     833             : 
     834          98 :     pg_hmac_free(ctx);
     835          98 :     return true;
     836             : }
     837             : 
     838             : /*
     839             :  * Validate the server signature, received as part of the final exchange
     840             :  * message received from the server.  *match tracks if the server signature
     841             :  * matched or not. Returns true if the server signature got verified, and
     842             :  * false for a processing error with *errstr pointing to a message about the
     843             :  * error details.
     844             :  */
     845             : static bool
     846          86 : verify_server_signature(fe_scram_state *state, bool *match,
     847             :                         const char **errstr)
     848             : {
     849             :     uint8       expected_ServerSignature[SCRAM_MAX_KEY_LEN];
     850             :     uint8       ServerKey[SCRAM_MAX_KEY_LEN];
     851             :     pg_hmac_ctx *ctx;
     852             : 
     853          86 :     ctx = pg_hmac_create(state->hash_type);
     854          86 :     if (ctx == NULL)
     855             :     {
     856           0 :         *errstr = pg_hmac_error(NULL);  /* returns OOM */
     857           0 :         return false;
     858             :     }
     859             : 
     860          86 :     if (state->conn->scram_server_key_binary)
     861             :     {
     862           8 :         memcpy(ServerKey, state->conn->scram_server_key_binary, SCRAM_MAX_KEY_LEN);
     863             :     }
     864             :     else
     865             :     {
     866          78 :         if (scram_ServerKey(state->SaltedPassword, state->hash_type,
     867             :                             state->key_length, ServerKey, errstr) < 0)
     868             :         {
     869             :             /* errstr is filled already */
     870           0 :             pg_hmac_free(ctx);
     871           0 :             return false;
     872             :         }
     873             :     }
     874             : 
     875             :     /* calculate ServerSignature */
     876         172 :     if (pg_hmac_init(ctx, ServerKey, state->key_length) < 0 ||
     877          86 :         pg_hmac_update(ctx,
     878          86 :                        (uint8 *) state->client_first_message_bare,
     879         172 :                        strlen(state->client_first_message_bare)) < 0 ||
     880         172 :         pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
     881          86 :         pg_hmac_update(ctx,
     882          86 :                        (uint8 *) state->server_first_message,
     883         172 :                        strlen(state->server_first_message)) < 0 ||
     884         172 :         pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
     885          86 :         pg_hmac_update(ctx,
     886          86 :                        (uint8 *) state->client_final_message_without_proof,
     887         172 :                        strlen(state->client_final_message_without_proof)) < 0 ||
     888          86 :         pg_hmac_final(ctx, expected_ServerSignature,
     889          86 :                       state->key_length) < 0)
     890             :     {
     891           0 :         *errstr = pg_hmac_error(ctx);
     892           0 :         pg_hmac_free(ctx);
     893           0 :         return false;
     894             :     }
     895             : 
     896          86 :     pg_hmac_free(ctx);
     897             : 
     898             :     /* signature processed, so now check after it */
     899          86 :     if (memcmp(expected_ServerSignature, state->ServerSignature,
     900          86 :                state->key_length) != 0)
     901           0 :         *match = false;
     902             :     else
     903          86 :         *match = true;
     904             : 
     905          86 :     return true;
     906             : }
     907             : 
     908             : /*
     909             :  * Build a new SCRAM secret.
     910             :  *
     911             :  * On error, returns NULL and sets *errstr to point to a message about the
     912             :  * error details.
     913             :  */
     914             : char *
     915           2 : pg_fe_scram_build_secret(const char *password, int iterations, const char **errstr)
     916             : {
     917             :     char       *prep_password;
     918             :     pg_saslprep_rc rc;
     919             :     char        saltbuf[SCRAM_DEFAULT_SALT_LEN];
     920             :     char       *result;
     921             : 
     922             :     /*
     923             :      * Normalize the password with SASLprep.  If that doesn't work, because
     924             :      * the password isn't valid UTF-8 or contains prohibited characters, just
     925             :      * proceed with the original password.  (See comments at the top of
     926             :      * auth-scram.c.)
     927             :      */
     928           2 :     rc = pg_saslprep(password, &prep_password);
     929           2 :     if (rc == SASLPREP_OOM)
     930             :     {
     931           0 :         *errstr = libpq_gettext("out of memory");
     932           0 :         return NULL;
     933             :     }
     934           2 :     if (rc == SASLPREP_SUCCESS)
     935           2 :         password = (const char *) prep_password;
     936             : 
     937             :     /* Generate a random salt */
     938           2 :     if (!pg_strong_random(saltbuf, SCRAM_DEFAULT_SALT_LEN))
     939             :     {
     940           0 :         *errstr = libpq_gettext("could not generate random salt");
     941           0 :         free(prep_password);
     942           0 :         return NULL;
     943             :     }
     944             : 
     945           2 :     result = scram_build_secret(PG_SHA256, SCRAM_SHA_256_KEY_LEN, saltbuf,
     946             :                                 SCRAM_DEFAULT_SALT_LEN,
     947             :                                 iterations, password,
     948             :                                 errstr);
     949             : 
     950           2 :     free(prep_password);
     951             : 
     952           2 :     return result;
     953             : }

Generated by: LCOV version 1.14