LCOV - code coverage report
Current view: top level - src/interfaces/libpq - fe-auth-scram.c (source / functions) Hit Total Coverage
Test: PostgreSQL 18devel Lines: 256 382 67.0 %
Date: 2025-01-18 04:15:08 Functions: 12 12 100.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : /*-------------------------------------------------------------------------
       2             :  *
       3             :  * fe-auth-scram.c
       4             :  *     The front-end (client) implementation of SCRAM authentication.
       5             :  *
       6             :  * Portions Copyright (c) 1996-2025, PostgreSQL Global Development Group
       7             :  * Portions Copyright (c) 1994, Regents of the University of California
       8             :  *
       9             :  * IDENTIFICATION
      10             :  *    src/interfaces/libpq/fe-auth-scram.c
      11             :  *
      12             :  *-------------------------------------------------------------------------
      13             :  */
      14             : 
      15             : #include "postgres_fe.h"
      16             : 
      17             : #include "common/base64.h"
      18             : #include "common/hmac.h"
      19             : #include "common/saslprep.h"
      20             : #include "common/scram-common.h"
      21             : #include "fe-auth.h"
      22             : 
      23             : 
      24             : /* The exported SCRAM callback mechanism. */
      25             : static void *scram_init(PGconn *conn, const char *password,
      26             :                         const char *sasl_mechanism);
      27             : static SASLStatus scram_exchange(void *opaq, char *input, int inputlen,
      28             :                                  char **output, int *outputlen);
      29             : static bool scram_channel_bound(void *opaq);
      30             : static void scram_free(void *opaq);
      31             : 
      32             : const pg_fe_sasl_mech pg_scram_mech = {
      33             :     scram_init,
      34             :     scram_exchange,
      35             :     scram_channel_bound,
      36             :     scram_free
      37             : };
      38             : 
      39             : /*
      40             :  * Status of exchange messages used for SCRAM authentication via the
      41             :  * SASL protocol.
      42             :  */
      43             : typedef enum
      44             : {
      45             :     FE_SCRAM_INIT,
      46             :     FE_SCRAM_NONCE_SENT,
      47             :     FE_SCRAM_PROOF_SENT,
      48             :     FE_SCRAM_FINISHED,
      49             : } fe_scram_state_enum;
      50             : 
      51             : typedef struct
      52             : {
      53             :     fe_scram_state_enum state;
      54             : 
      55             :     /* These are supplied by the user */
      56             :     PGconn     *conn;
      57             :     char       *password;
      58             :     char       *sasl_mechanism;
      59             : 
      60             :     /* State data depending on the hash type */
      61             :     pg_cryptohash_type hash_type;
      62             :     int         key_length;
      63             : 
      64             :     /* We construct these */
      65             :     uint8       SaltedPassword[SCRAM_MAX_KEY_LEN];
      66             :     char       *client_nonce;
      67             :     char       *client_first_message_bare;
      68             :     char       *client_final_message_without_proof;
      69             : 
      70             :     /* These come from the server-first message */
      71             :     char       *server_first_message;
      72             :     char       *salt;
      73             :     int         saltlen;
      74             :     int         iterations;
      75             :     char       *nonce;
      76             : 
      77             :     /* These come from the server-final message */
      78             :     char       *server_final_message;
      79             :     char        ServerSignature[SCRAM_MAX_KEY_LEN];
      80             : } fe_scram_state;
      81             : 
      82             : static bool read_server_first_message(fe_scram_state *state, char *input);
      83             : static bool read_server_final_message(fe_scram_state *state, char *input);
      84             : static char *build_client_first_message(fe_scram_state *state);
      85             : static char *build_client_final_message(fe_scram_state *state);
      86             : static bool verify_server_signature(fe_scram_state *state, bool *match,
      87             :                                     const char **errstr);
      88             : static bool calculate_client_proof(fe_scram_state *state,
      89             :                                    const char *client_final_message_without_proof,
      90             :                                    uint8 *result, const char **errstr);
      91             : 
      92             : /*
      93             :  * Initialize SCRAM exchange status.
      94             :  */
      95             : static void *
      96          98 : scram_init(PGconn *conn,
      97             :            const char *password,
      98             :            const char *sasl_mechanism)
      99             : {
     100             :     fe_scram_state *state;
     101             :     char       *prep_password;
     102             :     pg_saslprep_rc rc;
     103             : 
     104             :     Assert(sasl_mechanism != NULL);
     105             : 
     106          98 :     state = (fe_scram_state *) malloc(sizeof(fe_scram_state));
     107          98 :     if (!state)
     108           0 :         return NULL;
     109          98 :     memset(state, 0, sizeof(fe_scram_state));
     110          98 :     state->conn = conn;
     111          98 :     state->state = FE_SCRAM_INIT;
     112          98 :     state->key_length = SCRAM_SHA_256_KEY_LEN;
     113          98 :     state->hash_type = PG_SHA256;
     114             : 
     115          98 :     state->sasl_mechanism = strdup(sasl_mechanism);
     116          98 :     if (!state->sasl_mechanism)
     117             :     {
     118           0 :         free(state);
     119           0 :         return NULL;
     120             :     }
     121             : 
     122          98 :     if (password)
     123             :     {
     124             :         /* Normalize the password with SASLprep, if possible */
     125          90 :         rc = pg_saslprep(password, &prep_password);
     126          90 :         if (rc == SASLPREP_OOM)
     127             :         {
     128           0 :             free(state->sasl_mechanism);
     129           0 :             free(state);
     130           0 :             return NULL;
     131             :         }
     132          90 :         if (rc != SASLPREP_SUCCESS)
     133             :         {
     134           4 :             prep_password = strdup(password);
     135           4 :             if (!prep_password)
     136             :             {
     137           0 :                 free(state->sasl_mechanism);
     138           0 :                 free(state);
     139           0 :                 return NULL;
     140             :             }
     141             :         }
     142          90 :         state->password = prep_password;
     143             :     }
     144             : 
     145          98 :     return state;
     146             : }
     147             : 
     148             : /*
     149             :  * Return true if channel binding was employed and the SCRAM exchange
     150             :  * completed. This should be used after a successful exchange to determine
     151             :  * whether the server authenticated itself to the client.
     152             :  *
     153             :  * Note that the caller must also ensure that the exchange was actually
     154             :  * successful.
     155             :  */
     156             : static bool
     157           6 : scram_channel_bound(void *opaq)
     158             : {
     159           6 :     fe_scram_state *state = (fe_scram_state *) opaq;
     160             : 
     161             :     /* no SCRAM exchange done */
     162           6 :     if (state == NULL)
     163           0 :         return false;
     164             : 
     165             :     /* SCRAM exchange not completed */
     166           6 :     if (state->state != FE_SCRAM_FINISHED)
     167           0 :         return false;
     168             : 
     169             :     /* channel binding mechanism not used */
     170           6 :     if (strcmp(state->sasl_mechanism, SCRAM_SHA_256_PLUS_NAME) != 0)
     171           0 :         return false;
     172             : 
     173             :     /* all clear! */
     174           6 :     return true;
     175             : }
     176             : 
     177             : /*
     178             :  * Free SCRAM exchange status
     179             :  */
     180             : static void
     181          90 : scram_free(void *opaq)
     182             : {
     183          90 :     fe_scram_state *state = (fe_scram_state *) opaq;
     184             : 
     185          90 :     free(state->password);
     186          90 :     free(state->sasl_mechanism);
     187             : 
     188             :     /* client messages */
     189          90 :     free(state->client_nonce);
     190          90 :     free(state->client_first_message_bare);
     191          90 :     free(state->client_final_message_without_proof);
     192             : 
     193             :     /* first message from server */
     194          90 :     free(state->server_first_message);
     195          90 :     free(state->salt);
     196          90 :     free(state->nonce);
     197             : 
     198             :     /* final message from server */
     199          90 :     free(state->server_final_message);
     200             : 
     201          90 :     free(state);
     202          90 : }
     203             : 
     204             : /*
     205             :  * Exchange a SCRAM message with backend.
     206             :  */
     207             : static SASLStatus
     208         282 : scram_exchange(void *opaq, char *input, int inputlen,
     209             :                char **output, int *outputlen)
     210             : {
     211         282 :     fe_scram_state *state = (fe_scram_state *) opaq;
     212         282 :     PGconn     *conn = state->conn;
     213         282 :     const char *errstr = NULL;
     214             : 
     215         282 :     *output = NULL;
     216         282 :     *outputlen = 0;
     217             : 
     218             :     /*
     219             :      * Check that the input length agrees with the string length of the input.
     220             :      * We can ignore inputlen after this.
     221             :      */
     222         282 :     if (state->state != FE_SCRAM_INIT)
     223             :     {
     224         184 :         if (inputlen == 0)
     225             :         {
     226           0 :             libpq_append_conn_error(conn, "malformed SCRAM message (empty message)");
     227           0 :             return SASL_FAILED;
     228             :         }
     229         184 :         if (inputlen != strlen(input))
     230             :         {
     231           0 :             libpq_append_conn_error(conn, "malformed SCRAM message (length mismatch)");
     232           0 :             return SASL_FAILED;
     233             :         }
     234             :     }
     235             : 
     236         282 :     switch (state->state)
     237             :     {
     238          98 :         case FE_SCRAM_INIT:
     239             :             /* Begin the SCRAM handshake, by sending client nonce */
     240          98 :             *output = build_client_first_message(state);
     241          98 :             if (*output == NULL)
     242           0 :                 return SASL_FAILED;
     243             : 
     244          98 :             *outputlen = strlen(*output);
     245          98 :             state->state = FE_SCRAM_NONCE_SENT;
     246          98 :             return SASL_CONTINUE;
     247             : 
     248          98 :         case FE_SCRAM_NONCE_SENT:
     249             :             /* Receive salt and server nonce, send response. */
     250          98 :             if (!read_server_first_message(state, input))
     251           0 :                 return SASL_FAILED;
     252             : 
     253          98 :             *output = build_client_final_message(state);
     254          98 :             if (*output == NULL)
     255           0 :                 return SASL_FAILED;
     256             : 
     257          98 :             *outputlen = strlen(*output);
     258          98 :             state->state = FE_SCRAM_PROOF_SENT;
     259          98 :             return SASL_CONTINUE;
     260             : 
     261          86 :         case FE_SCRAM_PROOF_SENT:
     262             :             {
     263             :                 bool        match;
     264             : 
     265             :                 /* Receive server signature */
     266          86 :                 if (!read_server_final_message(state, input))
     267           0 :                     return SASL_FAILED;
     268             : 
     269             :                 /*
     270             :                  * Verify server signature, to make sure we're talking to the
     271             :                  * genuine server.
     272             :                  */
     273          86 :                 if (!verify_server_signature(state, &match, &errstr))
     274             :                 {
     275           0 :                     libpq_append_conn_error(conn, "could not verify server signature: %s", errstr);
     276           0 :                     return SASL_FAILED;
     277             :                 }
     278             : 
     279          86 :                 if (!match)
     280             :                 {
     281           0 :                     libpq_append_conn_error(conn, "incorrect server signature");
     282             :                 }
     283          86 :                 state->state = FE_SCRAM_FINISHED;
     284          86 :                 state->conn->client_finished_auth = true;
     285          86 :                 return match ? SASL_COMPLETE : SASL_FAILED;
     286             :             }
     287             : 
     288           0 :         default:
     289             :             /* shouldn't happen */
     290           0 :             libpq_append_conn_error(conn, "invalid SCRAM exchange state");
     291           0 :             break;
     292             :     }
     293             : 
     294           0 :     return SASL_FAILED;
     295             : }
     296             : 
     297             : /*
     298             :  * Read value for an attribute part of a SCRAM message.
     299             :  *
     300             :  * The buffer at **input is destructively modified, and *input is
     301             :  * advanced over the "attr=value" string and any following comma.
     302             :  *
     303             :  * On failure, append an error message to *errorMessage and return NULL.
     304             :  */
     305             : static char *
     306         380 : read_attr_value(char **input, char attr, PQExpBuffer errorMessage)
     307             : {
     308         380 :     char       *begin = *input;
     309             :     char       *end;
     310             : 
     311         380 :     if (*begin != attr)
     312             :     {
     313           0 :         libpq_append_error(errorMessage,
     314             :                            "malformed SCRAM message (attribute \"%c\" expected)",
     315             :                            attr);
     316           0 :         return NULL;
     317             :     }
     318         380 :     begin++;
     319             : 
     320         380 :     if (*begin != '=')
     321             :     {
     322           0 :         libpq_append_error(errorMessage,
     323             :                            "malformed SCRAM message (expected character \"=\" for attribute \"%c\")",
     324             :                            attr);
     325           0 :         return NULL;
     326             :     }
     327         380 :     begin++;
     328             : 
     329         380 :     end = begin;
     330       11608 :     while (*end && *end != ',')
     331       11228 :         end++;
     332             : 
     333         380 :     if (*end)
     334             :     {
     335         196 :         *end = '\0';
     336         196 :         *input = end + 1;
     337             :     }
     338             :     else
     339         184 :         *input = end;
     340             : 
     341         380 :     return begin;
     342             : }
     343             : 
     344             : /*
     345             :  * Build the first exchange message sent by the client.
     346             :  */
     347             : static char *
     348          98 : build_client_first_message(fe_scram_state *state)
     349             : {
     350          98 :     PGconn     *conn = state->conn;
     351             :     char        raw_nonce[SCRAM_RAW_NONCE_LEN + 1];
     352             :     char       *result;
     353             :     int         channel_info_len;
     354             :     int         encoded_len;
     355             :     PQExpBufferData buf;
     356             : 
     357             :     /*
     358             :      * Generate a "raw" nonce.  This is converted to ASCII-printable form by
     359             :      * base64-encoding it.
     360             :      */
     361          98 :     if (!pg_strong_random(raw_nonce, SCRAM_RAW_NONCE_LEN))
     362             :     {
     363           0 :         libpq_append_conn_error(conn, "could not generate nonce");
     364           0 :         return NULL;
     365             :     }
     366             : 
     367          98 :     encoded_len = pg_b64_enc_len(SCRAM_RAW_NONCE_LEN);
     368             :     /* don't forget the zero-terminator */
     369          98 :     state->client_nonce = malloc(encoded_len + 1);
     370          98 :     if (state->client_nonce == NULL)
     371             :     {
     372           0 :         libpq_append_conn_error(conn, "out of memory");
     373           0 :         return NULL;
     374             :     }
     375          98 :     encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN,
     376             :                                 state->client_nonce, encoded_len);
     377          98 :     if (encoded_len < 0)
     378             :     {
     379           0 :         libpq_append_conn_error(conn, "could not encode nonce");
     380           0 :         return NULL;
     381             :     }
     382          98 :     state->client_nonce[encoded_len] = '\0';
     383             : 
     384             :     /*
     385             :      * Generate message.  The username is left empty as the backend uses the
     386             :      * value provided by the startup packet.  Also, as this username is not
     387             :      * prepared with SASLprep, the message parsing would fail if it includes
     388             :      * '=' or ',' characters.
     389             :      */
     390             : 
     391          98 :     initPQExpBuffer(&buf);
     392             : 
     393             :     /*
     394             :      * First build the gs2-header with channel binding information.
     395             :      */
     396          98 :     if (strcmp(state->sasl_mechanism, SCRAM_SHA_256_PLUS_NAME) == 0)
     397             :     {
     398             :         Assert(conn->ssl_in_use);
     399          10 :         appendPQExpBufferStr(&buf, "p=tls-server-end-point");
     400             :     }
     401             : #ifdef USE_SSL
     402          88 :     else if (conn->channel_binding[0] != 'd' && /* disable */
     403          84 :              conn->ssl_in_use)
     404             :     {
     405             :         /*
     406             :          * Client supports channel binding, but thinks the server does not.
     407             :          */
     408           0 :         appendPQExpBufferChar(&buf, 'y');
     409             :     }
     410             : #endif
     411             :     else
     412             :     {
     413             :         /*
     414             :          * Client does not support channel binding, or has disabled it.
     415             :          */
     416          88 :         appendPQExpBufferChar(&buf, 'n');
     417             :     }
     418             : 
     419          98 :     if (PQExpBufferDataBroken(buf))
     420           0 :         goto oom_error;
     421             : 
     422          98 :     channel_info_len = buf.len;
     423             : 
     424          98 :     appendPQExpBuffer(&buf, ",,n=,r=%s", state->client_nonce);
     425          98 :     if (PQExpBufferDataBroken(buf))
     426           0 :         goto oom_error;
     427             : 
     428             :     /*
     429             :      * The first message content needs to be saved without channel binding
     430             :      * information.
     431             :      */
     432          98 :     state->client_first_message_bare = strdup(buf.data + channel_info_len + 2);
     433          98 :     if (!state->client_first_message_bare)
     434           0 :         goto oom_error;
     435             : 
     436          98 :     result = strdup(buf.data);
     437          98 :     if (result == NULL)
     438           0 :         goto oom_error;
     439             : 
     440          98 :     termPQExpBuffer(&buf);
     441          98 :     return result;
     442             : 
     443           0 : oom_error:
     444           0 :     termPQExpBuffer(&buf);
     445           0 :     libpq_append_conn_error(conn, "out of memory");
     446           0 :     return NULL;
     447             : }
     448             : 
     449             : /*
     450             :  * Build the final exchange message sent from the client.
     451             :  */
     452             : static char *
     453          98 : build_client_final_message(fe_scram_state *state)
     454             : {
     455             :     PQExpBufferData buf;
     456          98 :     PGconn     *conn = state->conn;
     457             :     uint8       client_proof[SCRAM_MAX_KEY_LEN];
     458             :     char       *result;
     459             :     int         encoded_len;
     460          98 :     const char *errstr = NULL;
     461             : 
     462          98 :     initPQExpBuffer(&buf);
     463             : 
     464             :     /*
     465             :      * Construct client-final-message-without-proof.  We need to remember it
     466             :      * for verifying the server proof in the final step of authentication.
     467             :      *
     468             :      * The channel binding flag handling (p/y/n) must be consistent with
     469             :      * build_client_first_message(), because the server will check that it's
     470             :      * the same flag both times.
     471             :      */
     472          98 :     if (strcmp(state->sasl_mechanism, SCRAM_SHA_256_PLUS_NAME) == 0)
     473             :     {
     474             : #ifdef USE_SSL
     475          10 :         char       *cbind_data = NULL;
     476          10 :         size_t      cbind_data_len = 0;
     477             :         size_t      cbind_header_len;
     478             :         char       *cbind_input;
     479             :         size_t      cbind_input_len;
     480             :         int         encoded_cbind_len;
     481             : 
     482             :         /* Fetch hash data of server's SSL certificate */
     483             :         cbind_data =
     484          10 :             pgtls_get_peer_certificate_hash(state->conn,
     485             :                                             &cbind_data_len);
     486          10 :         if (cbind_data == NULL)
     487             :         {
     488             :             /* error message is already set on error */
     489           0 :             termPQExpBuffer(&buf);
     490           0 :             return NULL;
     491             :         }
     492             : 
     493          10 :         appendPQExpBufferStr(&buf, "c=");
     494             : 
     495             :         /* p=type,, */
     496          10 :         cbind_header_len = strlen("p=tls-server-end-point,,");
     497          10 :         cbind_input_len = cbind_header_len + cbind_data_len;
     498          10 :         cbind_input = malloc(cbind_input_len);
     499          10 :         if (!cbind_input)
     500             :         {
     501           0 :             free(cbind_data);
     502           0 :             goto oom_error;
     503             :         }
     504          10 :         memcpy(cbind_input, "p=tls-server-end-point,,", cbind_header_len);
     505          10 :         memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len);
     506             : 
     507          10 :         encoded_cbind_len = pg_b64_enc_len(cbind_input_len);
     508          10 :         if (!enlargePQExpBuffer(&buf, encoded_cbind_len))
     509             :         {
     510           0 :             free(cbind_data);
     511           0 :             free(cbind_input);
     512           0 :             goto oom_error;
     513             :         }
     514          10 :         encoded_cbind_len = pg_b64_encode(cbind_input, cbind_input_len,
     515          10 :                                           buf.data + buf.len,
     516             :                                           encoded_cbind_len);
     517          10 :         if (encoded_cbind_len < 0)
     518             :         {
     519           0 :             free(cbind_data);
     520           0 :             free(cbind_input);
     521           0 :             termPQExpBuffer(&buf);
     522           0 :             appendPQExpBufferStr(&conn->errorMessage,
     523             :                                  "could not encode cbind data for channel binding\n");
     524           0 :             return NULL;
     525             :         }
     526          10 :         buf.len += encoded_cbind_len;
     527          10 :         buf.data[buf.len] = '\0';
     528             : 
     529          10 :         free(cbind_data);
     530          10 :         free(cbind_input);
     531             : #else
     532             :         /*
     533             :          * Chose channel binding, but the SSL library doesn't support it.
     534             :          * Shouldn't happen.
     535             :          */
     536             :         termPQExpBuffer(&buf);
     537             :         appendPQExpBufferStr(&conn->errorMessage,
     538             :                              "channel binding not supported by this build\n");
     539             :         return NULL;
     540             : #endif                          /* USE_SSL */
     541             :     }
     542             : #ifdef USE_SSL
     543          88 :     else if (conn->channel_binding[0] != 'd' && /* disable */
     544          84 :              conn->ssl_in_use)
     545           0 :         appendPQExpBufferStr(&buf, "c=eSws"); /* base64 of "y,," */
     546             : #endif
     547             :     else
     548          88 :         appendPQExpBufferStr(&buf, "c=biws"); /* base64 of "n,," */
     549             : 
     550          98 :     if (PQExpBufferDataBroken(buf))
     551           0 :         goto oom_error;
     552             : 
     553          98 :     appendPQExpBuffer(&buf, ",r=%s", state->nonce);
     554          98 :     if (PQExpBufferDataBroken(buf))
     555           0 :         goto oom_error;
     556             : 
     557          98 :     state->client_final_message_without_proof = strdup(buf.data);
     558          98 :     if (state->client_final_message_without_proof == NULL)
     559           0 :         goto oom_error;
     560             : 
     561             :     /* Append proof to it, to form client-final-message. */
     562          98 :     if (!calculate_client_proof(state,
     563          98 :                                 state->client_final_message_without_proof,
     564             :                                 client_proof, &errstr))
     565             :     {
     566           0 :         termPQExpBuffer(&buf);
     567           0 :         libpq_append_conn_error(conn, "could not calculate client proof: %s", errstr);
     568           0 :         return NULL;
     569             :     }
     570             : 
     571          98 :     appendPQExpBufferStr(&buf, ",p=");
     572          98 :     encoded_len = pg_b64_enc_len(state->key_length);
     573          98 :     if (!enlargePQExpBuffer(&buf, encoded_len))
     574           0 :         goto oom_error;
     575          98 :     encoded_len = pg_b64_encode((char *) client_proof,
     576             :                                 state->key_length,
     577          98 :                                 buf.data + buf.len,
     578             :                                 encoded_len);
     579          98 :     if (encoded_len < 0)
     580             :     {
     581           0 :         termPQExpBuffer(&buf);
     582           0 :         libpq_append_conn_error(conn, "could not encode client proof");
     583           0 :         return NULL;
     584             :     }
     585          98 :     buf.len += encoded_len;
     586          98 :     buf.data[buf.len] = '\0';
     587             : 
     588          98 :     result = strdup(buf.data);
     589          98 :     if (result == NULL)
     590           0 :         goto oom_error;
     591             : 
     592          98 :     termPQExpBuffer(&buf);
     593          98 :     return result;
     594             : 
     595           0 : oom_error:
     596           0 :     termPQExpBuffer(&buf);
     597           0 :     libpq_append_conn_error(conn, "out of memory");
     598           0 :     return NULL;
     599             : }
     600             : 
     601             : /*
     602             :  * Read the first exchange message coming from the server.
     603             :  */
     604             : static bool
     605          98 : read_server_first_message(fe_scram_state *state, char *input)
     606             : {
     607          98 :     PGconn     *conn = state->conn;
     608             :     char       *iterations_str;
     609             :     char       *endptr;
     610             :     char       *encoded_salt;
     611             :     char       *nonce;
     612             :     int         decoded_salt_len;
     613             : 
     614          98 :     state->server_first_message = strdup(input);
     615          98 :     if (state->server_first_message == NULL)
     616             :     {
     617           0 :         libpq_append_conn_error(conn, "out of memory");
     618           0 :         return false;
     619             :     }
     620             : 
     621             :     /* parse the message */
     622          98 :     nonce = read_attr_value(&input, 'r',
     623             :                             &conn->errorMessage);
     624          98 :     if (nonce == NULL)
     625             :     {
     626             :         /* read_attr_value() has appended an error string */
     627           0 :         return false;
     628             :     }
     629             : 
     630             :     /* Verify immediately that the server used our part of the nonce */
     631          98 :     if (strlen(nonce) < strlen(state->client_nonce) ||
     632          98 :         memcmp(nonce, state->client_nonce, strlen(state->client_nonce)) != 0)
     633             :     {
     634           0 :         libpq_append_conn_error(conn, "invalid SCRAM response (nonce mismatch)");
     635           0 :         return false;
     636             :     }
     637             : 
     638          98 :     state->nonce = strdup(nonce);
     639          98 :     if (state->nonce == NULL)
     640             :     {
     641           0 :         libpq_append_conn_error(conn, "out of memory");
     642           0 :         return false;
     643             :     }
     644             : 
     645          98 :     encoded_salt = read_attr_value(&input, 's', &conn->errorMessage);
     646          98 :     if (encoded_salt == NULL)
     647             :     {
     648             :         /* read_attr_value() has appended an error string */
     649           0 :         return false;
     650             :     }
     651          98 :     decoded_salt_len = pg_b64_dec_len(strlen(encoded_salt));
     652          98 :     state->salt = malloc(decoded_salt_len);
     653          98 :     if (state->salt == NULL)
     654             :     {
     655           0 :         libpq_append_conn_error(conn, "out of memory");
     656           0 :         return false;
     657             :     }
     658         196 :     state->saltlen = pg_b64_decode(encoded_salt,
     659          98 :                                    strlen(encoded_salt),
     660             :                                    state->salt,
     661             :                                    decoded_salt_len);
     662          98 :     if (state->saltlen < 0)
     663             :     {
     664           0 :         libpq_append_conn_error(conn, "malformed SCRAM message (invalid salt)");
     665           0 :         return false;
     666             :     }
     667             : 
     668          98 :     iterations_str = read_attr_value(&input, 'i', &conn->errorMessage);
     669          98 :     if (iterations_str == NULL)
     670             :     {
     671             :         /* read_attr_value() has appended an error string */
     672           0 :         return false;
     673             :     }
     674          98 :     state->iterations = strtol(iterations_str, &endptr, 10);
     675          98 :     if (*endptr != '\0' || state->iterations < 1)
     676             :     {
     677           0 :         libpq_append_conn_error(conn, "malformed SCRAM message (invalid iteration count)");
     678           0 :         return false;
     679             :     }
     680             : 
     681          98 :     if (*input != '\0')
     682           0 :         libpq_append_conn_error(conn, "malformed SCRAM message (garbage at end of server-first-message)");
     683             : 
     684          98 :     return true;
     685             : }
     686             : 
     687             : /*
     688             :  * Read the final exchange message coming from the server.
     689             :  */
     690             : static bool
     691          86 : read_server_final_message(fe_scram_state *state, char *input)
     692             : {
     693          86 :     PGconn     *conn = state->conn;
     694             :     char       *encoded_server_signature;
     695             :     char       *decoded_server_signature;
     696             :     int         server_signature_len;
     697             : 
     698          86 :     state->server_final_message = strdup(input);
     699          86 :     if (!state->server_final_message)
     700             :     {
     701           0 :         libpq_append_conn_error(conn, "out of memory");
     702           0 :         return false;
     703             :     }
     704             : 
     705             :     /* Check for error result. */
     706          86 :     if (*input == 'e')
     707             :     {
     708           0 :         char       *errmsg = read_attr_value(&input, 'e',
     709             :                                              &conn->errorMessage);
     710             : 
     711           0 :         if (errmsg == NULL)
     712             :         {
     713             :             /* read_attr_value() has appended an error message */
     714           0 :             return false;
     715             :         }
     716           0 :         libpq_append_conn_error(conn, "error received from server in SCRAM exchange: %s",
     717             :                                 errmsg);
     718           0 :         return false;
     719             :     }
     720             : 
     721             :     /* Parse the message. */
     722          86 :     encoded_server_signature = read_attr_value(&input, 'v',
     723             :                                                &conn->errorMessage);
     724          86 :     if (encoded_server_signature == NULL)
     725             :     {
     726             :         /* read_attr_value() has appended an error message */
     727           0 :         return false;
     728             :     }
     729             : 
     730          86 :     if (*input != '\0')
     731           0 :         libpq_append_conn_error(conn, "malformed SCRAM message (garbage at end of server-final-message)");
     732             : 
     733          86 :     server_signature_len = pg_b64_dec_len(strlen(encoded_server_signature));
     734          86 :     decoded_server_signature = malloc(server_signature_len);
     735          86 :     if (!decoded_server_signature)
     736             :     {
     737           0 :         libpq_append_conn_error(conn, "out of memory");
     738           0 :         return false;
     739             :     }
     740             : 
     741          86 :     server_signature_len = pg_b64_decode(encoded_server_signature,
     742          86 :                                          strlen(encoded_server_signature),
     743             :                                          decoded_server_signature,
     744             :                                          server_signature_len);
     745          86 :     if (server_signature_len != state->key_length)
     746             :     {
     747           0 :         free(decoded_server_signature);
     748           0 :         libpq_append_conn_error(conn, "malformed SCRAM message (invalid server signature)");
     749           0 :         return false;
     750             :     }
     751          86 :     memcpy(state->ServerSignature, decoded_server_signature,
     752          86 :            state->key_length);
     753          86 :     free(decoded_server_signature);
     754             : 
     755          86 :     return true;
     756             : }
     757             : 
     758             : /*
     759             :  * Calculate the client proof, part of the final exchange message sent
     760             :  * by the client.  Returns true on success, false on failure with *errstr
     761             :  * pointing to a message about the error details.
     762             :  */
     763             : static bool
     764          98 : calculate_client_proof(fe_scram_state *state,
     765             :                        const char *client_final_message_without_proof,
     766             :                        uint8 *result, const char **errstr)
     767             : {
     768             :     uint8       StoredKey[SCRAM_MAX_KEY_LEN];
     769             :     uint8       ClientKey[SCRAM_MAX_KEY_LEN];
     770             :     uint8       ClientSignature[SCRAM_MAX_KEY_LEN];
     771             :     int         i;
     772             :     pg_hmac_ctx *ctx;
     773             : 
     774          98 :     ctx = pg_hmac_create(state->hash_type);
     775          98 :     if (ctx == NULL)
     776             :     {
     777           0 :         *errstr = pg_hmac_error(NULL);  /* returns OOM */
     778           0 :         return false;
     779             :     }
     780             : 
     781          98 :     if (state->conn->scram_client_key_binary)
     782             :     {
     783           8 :         memcpy(ClientKey, state->conn->scram_client_key_binary, SCRAM_MAX_KEY_LEN);
     784             :     }
     785             :     else
     786             :     {
     787             :         /*
     788             :          * Calculate SaltedPassword, and store it in 'state' so that we can
     789             :          * reuse it later in verify_server_signature.
     790             :          */
     791          90 :         if (scram_SaltedPassword(state->password, state->hash_type,
     792          90 :                                  state->key_length, state->salt, state->saltlen,
     793          90 :                                  state->iterations, state->SaltedPassword,
     794          90 :                                  errstr) < 0 ||
     795          90 :             scram_ClientKey(state->SaltedPassword, state->hash_type,
     796             :                             state->key_length, ClientKey, errstr) < 0)
     797             :         {
     798             :             /* errstr is already filled here */
     799           0 :             pg_hmac_free(ctx);
     800           0 :             return false;
     801             :         }
     802             :     }
     803             : 
     804          98 :     if (scram_H(ClientKey, state->hash_type, state->key_length, StoredKey, errstr) < 0)
     805             :     {
     806           0 :         pg_hmac_free(ctx);
     807           0 :         return false;
     808             :     }
     809             : 
     810         196 :     if (pg_hmac_init(ctx, StoredKey, state->key_length) < 0 ||
     811          98 :         pg_hmac_update(ctx,
     812          98 :                        (uint8 *) state->client_first_message_bare,
     813         196 :                        strlen(state->client_first_message_bare)) < 0 ||
     814         196 :         pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
     815          98 :         pg_hmac_update(ctx,
     816          98 :                        (uint8 *) state->server_first_message,
     817         196 :                        strlen(state->server_first_message)) < 0 ||
     818         196 :         pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
     819          98 :         pg_hmac_update(ctx,
     820             :                        (uint8 *) client_final_message_without_proof,
     821          98 :                        strlen(client_final_message_without_proof)) < 0 ||
     822          98 :         pg_hmac_final(ctx, ClientSignature, state->key_length) < 0)
     823             :     {
     824           0 :         *errstr = pg_hmac_error(ctx);
     825           0 :         pg_hmac_free(ctx);
     826           0 :         return false;
     827             :     }
     828             : 
     829        3234 :     for (i = 0; i < state->key_length; i++)
     830        3136 :         result[i] = ClientKey[i] ^ ClientSignature[i];
     831             : 
     832          98 :     pg_hmac_free(ctx);
     833          98 :     return true;
     834             : }
     835             : 
     836             : /*
     837             :  * Validate the server signature, received as part of the final exchange
     838             :  * message received from the server.  *match tracks if the server signature
     839             :  * matched or not. Returns true if the server signature got verified, and
     840             :  * false for a processing error with *errstr pointing to a message about the
     841             :  * error details.
     842             :  */
     843             : static bool
     844          86 : verify_server_signature(fe_scram_state *state, bool *match,
     845             :                         const char **errstr)
     846             : {
     847             :     uint8       expected_ServerSignature[SCRAM_MAX_KEY_LEN];
     848             :     uint8       ServerKey[SCRAM_MAX_KEY_LEN];
     849             :     pg_hmac_ctx *ctx;
     850             : 
     851          86 :     ctx = pg_hmac_create(state->hash_type);
     852          86 :     if (ctx == NULL)
     853             :     {
     854           0 :         *errstr = pg_hmac_error(NULL);  /* returns OOM */
     855           0 :         return false;
     856             :     }
     857             : 
     858          86 :     if (state->conn->scram_server_key_binary)
     859             :     {
     860           8 :         memcpy(ServerKey, state->conn->scram_server_key_binary, SCRAM_MAX_KEY_LEN);
     861             :     }
     862             :     else
     863             :     {
     864          78 :         if (scram_ServerKey(state->SaltedPassword, state->hash_type,
     865             :                             state->key_length, ServerKey, errstr) < 0)
     866             :         {
     867             :             /* errstr is filled already */
     868           0 :             pg_hmac_free(ctx);
     869           0 :             return false;
     870             :         }
     871             :     }
     872             : 
     873             :     /* calculate ServerSignature */
     874         172 :     if (pg_hmac_init(ctx, ServerKey, state->key_length) < 0 ||
     875          86 :         pg_hmac_update(ctx,
     876          86 :                        (uint8 *) state->client_first_message_bare,
     877         172 :                        strlen(state->client_first_message_bare)) < 0 ||
     878         172 :         pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
     879          86 :         pg_hmac_update(ctx,
     880          86 :                        (uint8 *) state->server_first_message,
     881         172 :                        strlen(state->server_first_message)) < 0 ||
     882         172 :         pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
     883          86 :         pg_hmac_update(ctx,
     884          86 :                        (uint8 *) state->client_final_message_without_proof,
     885         172 :                        strlen(state->client_final_message_without_proof)) < 0 ||
     886          86 :         pg_hmac_final(ctx, expected_ServerSignature,
     887          86 :                       state->key_length) < 0)
     888             :     {
     889           0 :         *errstr = pg_hmac_error(ctx);
     890           0 :         pg_hmac_free(ctx);
     891           0 :         return false;
     892             :     }
     893             : 
     894          86 :     pg_hmac_free(ctx);
     895             : 
     896             :     /* signature processed, so now check after it */
     897          86 :     if (memcmp(expected_ServerSignature, state->ServerSignature,
     898          86 :                state->key_length) != 0)
     899           0 :         *match = false;
     900             :     else
     901          86 :         *match = true;
     902             : 
     903          86 :     return true;
     904             : }
     905             : 
     906             : /*
     907             :  * Build a new SCRAM secret.
     908             :  *
     909             :  * On error, returns NULL and sets *errstr to point to a message about the
     910             :  * error details.
     911             :  */
     912             : char *
     913           2 : pg_fe_scram_build_secret(const char *password, int iterations, const char **errstr)
     914             : {
     915             :     char       *prep_password;
     916             :     pg_saslprep_rc rc;
     917             :     char        saltbuf[SCRAM_DEFAULT_SALT_LEN];
     918             :     char       *result;
     919             : 
     920             :     /*
     921             :      * Normalize the password with SASLprep.  If that doesn't work, because
     922             :      * the password isn't valid UTF-8 or contains prohibited characters, just
     923             :      * proceed with the original password.  (See comments at the top of
     924             :      * auth-scram.c.)
     925             :      */
     926           2 :     rc = pg_saslprep(password, &prep_password);
     927           2 :     if (rc == SASLPREP_OOM)
     928             :     {
     929           0 :         *errstr = libpq_gettext("out of memory");
     930           0 :         return NULL;
     931             :     }
     932           2 :     if (rc == SASLPREP_SUCCESS)
     933           2 :         password = (const char *) prep_password;
     934             : 
     935             :     /* Generate a random salt */
     936           2 :     if (!pg_strong_random(saltbuf, SCRAM_DEFAULT_SALT_LEN))
     937             :     {
     938           0 :         *errstr = libpq_gettext("could not generate random salt");
     939           0 :         free(prep_password);
     940           0 :         return NULL;
     941             :     }
     942             : 
     943           2 :     result = scram_build_secret(PG_SHA256, SCRAM_SHA_256_KEY_LEN, saltbuf,
     944             :                                 SCRAM_DEFAULT_SALT_LEN,
     945             :                                 iterations, password,
     946             :                                 errstr);
     947             : 
     948           2 :     free(prep_password);
     949             : 
     950           2 :     return result;
     951             : }

Generated by: LCOV version 1.14