LCOV - code coverage report
Current view: top level - src/interfaces/libpq - fe-auth-scram.c (source / functions) Hit Total Coverage
Test: PostgreSQL 18devel Lines: 252 376 67.0 %
Date: 2024-11-21 08:14:44 Functions: 12 12 100.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : /*-------------------------------------------------------------------------
       2             :  *
       3             :  * fe-auth-scram.c
       4             :  *     The front-end (client) implementation of SCRAM authentication.
       5             :  *
       6             :  * Portions Copyright (c) 1996-2024, PostgreSQL Global Development Group
       7             :  * Portions Copyright (c) 1994, Regents of the University of California
       8             :  *
       9             :  * IDENTIFICATION
      10             :  *    src/interfaces/libpq/fe-auth-scram.c
      11             :  *
      12             :  *-------------------------------------------------------------------------
      13             :  */
      14             : 
      15             : #include "postgres_fe.h"
      16             : 
      17             : #include "common/base64.h"
      18             : #include "common/hmac.h"
      19             : #include "common/saslprep.h"
      20             : #include "common/scram-common.h"
      21             : #include "fe-auth.h"
      22             : 
      23             : 
      24             : /* The exported SCRAM callback mechanism. */
      25             : static void *scram_init(PGconn *conn, const char *password,
      26             :                         const char *sasl_mechanism);
      27             : static SASLStatus scram_exchange(void *opaq, char *input, int inputlen,
      28             :                                  char **output, int *outputlen);
      29             : static bool scram_channel_bound(void *opaq);
      30             : static void scram_free(void *opaq);
      31             : 
      32             : const pg_fe_sasl_mech pg_scram_mech = {
      33             :     scram_init,
      34             :     scram_exchange,
      35             :     scram_channel_bound,
      36             :     scram_free
      37             : };
      38             : 
      39             : /*
      40             :  * Status of exchange messages used for SCRAM authentication via the
      41             :  * SASL protocol.
      42             :  */
      43             : typedef enum
      44             : {
      45             :     FE_SCRAM_INIT,
      46             :     FE_SCRAM_NONCE_SENT,
      47             :     FE_SCRAM_PROOF_SENT,
      48             :     FE_SCRAM_FINISHED,
      49             : } fe_scram_state_enum;
      50             : 
      51             : typedef struct
      52             : {
      53             :     fe_scram_state_enum state;
      54             : 
      55             :     /* These are supplied by the user */
      56             :     PGconn     *conn;
      57             :     char       *password;
      58             :     char       *sasl_mechanism;
      59             : 
      60             :     /* State data depending on the hash type */
      61             :     pg_cryptohash_type hash_type;
      62             :     int         key_length;
      63             : 
      64             :     /* We construct these */
      65             :     uint8       SaltedPassword[SCRAM_MAX_KEY_LEN];
      66             :     char       *client_nonce;
      67             :     char       *client_first_message_bare;
      68             :     char       *client_final_message_without_proof;
      69             : 
      70             :     /* These come from the server-first message */
      71             :     char       *server_first_message;
      72             :     char       *salt;
      73             :     int         saltlen;
      74             :     int         iterations;
      75             :     char       *nonce;
      76             : 
      77             :     /* These come from the server-final message */
      78             :     char       *server_final_message;
      79             :     char        ServerSignature[SCRAM_MAX_KEY_LEN];
      80             : } fe_scram_state;
      81             : 
      82             : static bool read_server_first_message(fe_scram_state *state, char *input);
      83             : static bool read_server_final_message(fe_scram_state *state, char *input);
      84             : static char *build_client_first_message(fe_scram_state *state);
      85             : static char *build_client_final_message(fe_scram_state *state);
      86             : static bool verify_server_signature(fe_scram_state *state, bool *match,
      87             :                                     const char **errstr);
      88             : static bool calculate_client_proof(fe_scram_state *state,
      89             :                                    const char *client_final_message_without_proof,
      90             :                                    uint8 *result, const char **errstr);
      91             : 
      92             : /*
      93             :  * Initialize SCRAM exchange status.
      94             :  */
      95             : static void *
      96          80 : scram_init(PGconn *conn,
      97             :            const char *password,
      98             :            const char *sasl_mechanism)
      99             : {
     100             :     fe_scram_state *state;
     101             :     char       *prep_password;
     102             :     pg_saslprep_rc rc;
     103             : 
     104             :     Assert(sasl_mechanism != NULL);
     105             : 
     106          80 :     state = (fe_scram_state *) malloc(sizeof(fe_scram_state));
     107          80 :     if (!state)
     108           0 :         return NULL;
     109          80 :     memset(state, 0, sizeof(fe_scram_state));
     110          80 :     state->conn = conn;
     111          80 :     state->state = FE_SCRAM_INIT;
     112          80 :     state->key_length = SCRAM_SHA_256_KEY_LEN;
     113          80 :     state->hash_type = PG_SHA256;
     114             : 
     115          80 :     state->sasl_mechanism = strdup(sasl_mechanism);
     116          80 :     if (!state->sasl_mechanism)
     117             :     {
     118           0 :         free(state);
     119           0 :         return NULL;
     120             :     }
     121             : 
     122             :     /* Normalize the password with SASLprep, if possible */
     123          80 :     rc = pg_saslprep(password, &prep_password);
     124          80 :     if (rc == SASLPREP_OOM)
     125             :     {
     126           0 :         free(state->sasl_mechanism);
     127           0 :         free(state);
     128           0 :         return NULL;
     129             :     }
     130          80 :     if (rc != SASLPREP_SUCCESS)
     131             :     {
     132           4 :         prep_password = strdup(password);
     133           4 :         if (!prep_password)
     134             :         {
     135           0 :             free(state->sasl_mechanism);
     136           0 :             free(state);
     137           0 :             return NULL;
     138             :         }
     139             :     }
     140          80 :     state->password = prep_password;
     141             : 
     142          80 :     return state;
     143             : }
     144             : 
     145             : /*
     146             :  * Return true if channel binding was employed and the SCRAM exchange
     147             :  * completed. This should be used after a successful exchange to determine
     148             :  * whether the server authenticated itself to the client.
     149             :  *
     150             :  * Note that the caller must also ensure that the exchange was actually
     151             :  * successful.
     152             :  */
     153             : static bool
     154           6 : scram_channel_bound(void *opaq)
     155             : {
     156           6 :     fe_scram_state *state = (fe_scram_state *) opaq;
     157             : 
     158             :     /* no SCRAM exchange done */
     159           6 :     if (state == NULL)
     160           0 :         return false;
     161             : 
     162             :     /* SCRAM exchange not completed */
     163           6 :     if (state->state != FE_SCRAM_FINISHED)
     164           0 :         return false;
     165             : 
     166             :     /* channel binding mechanism not used */
     167           6 :     if (strcmp(state->sasl_mechanism, SCRAM_SHA_256_PLUS_NAME) != 0)
     168           0 :         return false;
     169             : 
     170             :     /* all clear! */
     171           6 :     return true;
     172             : }
     173             : 
     174             : /*
     175             :  * Free SCRAM exchange status
     176             :  */
     177             : static void
     178          80 : scram_free(void *opaq)
     179             : {
     180          80 :     fe_scram_state *state = (fe_scram_state *) opaq;
     181             : 
     182          80 :     free(state->password);
     183          80 :     free(state->sasl_mechanism);
     184             : 
     185             :     /* client messages */
     186          80 :     free(state->client_nonce);
     187          80 :     free(state->client_first_message_bare);
     188          80 :     free(state->client_final_message_without_proof);
     189             : 
     190             :     /* first message from server */
     191          80 :     free(state->server_first_message);
     192          80 :     free(state->salt);
     193          80 :     free(state->nonce);
     194             : 
     195             :     /* final message from server */
     196          80 :     free(state->server_final_message);
     197             : 
     198          80 :     free(state);
     199          80 : }
     200             : 
     201             : /*
     202             :  * Exchange a SCRAM message with backend.
     203             :  */
     204             : static SASLStatus
     205         228 : scram_exchange(void *opaq, char *input, int inputlen,
     206             :                char **output, int *outputlen)
     207             : {
     208         228 :     fe_scram_state *state = (fe_scram_state *) opaq;
     209         228 :     PGconn     *conn = state->conn;
     210         228 :     const char *errstr = NULL;
     211             : 
     212         228 :     *output = NULL;
     213         228 :     *outputlen = 0;
     214             : 
     215             :     /*
     216             :      * Check that the input length agrees with the string length of the input.
     217             :      * We can ignore inputlen after this.
     218             :      */
     219         228 :     if (state->state != FE_SCRAM_INIT)
     220             :     {
     221         148 :         if (inputlen == 0)
     222             :         {
     223           0 :             libpq_append_conn_error(conn, "malformed SCRAM message (empty message)");
     224           0 :             return SASL_FAILED;
     225             :         }
     226         148 :         if (inputlen != strlen(input))
     227             :         {
     228           0 :             libpq_append_conn_error(conn, "malformed SCRAM message (length mismatch)");
     229           0 :             return SASL_FAILED;
     230             :         }
     231             :     }
     232             : 
     233         228 :     switch (state->state)
     234             :     {
     235          80 :         case FE_SCRAM_INIT:
     236             :             /* Begin the SCRAM handshake, by sending client nonce */
     237          80 :             *output = build_client_first_message(state);
     238          80 :             if (*output == NULL)
     239           0 :                 return SASL_FAILED;
     240             : 
     241          80 :             *outputlen = strlen(*output);
     242          80 :             state->state = FE_SCRAM_NONCE_SENT;
     243          80 :             return SASL_CONTINUE;
     244             : 
     245          80 :         case FE_SCRAM_NONCE_SENT:
     246             :             /* Receive salt and server nonce, send response. */
     247          80 :             if (!read_server_first_message(state, input))
     248           0 :                 return SASL_FAILED;
     249             : 
     250          80 :             *output = build_client_final_message(state);
     251          80 :             if (*output == NULL)
     252           0 :                 return SASL_FAILED;
     253             : 
     254          80 :             *outputlen = strlen(*output);
     255          80 :             state->state = FE_SCRAM_PROOF_SENT;
     256          80 :             return SASL_CONTINUE;
     257             : 
     258          68 :         case FE_SCRAM_PROOF_SENT:
     259             :             {
     260             :                 bool        match;
     261             : 
     262             :                 /* Receive server signature */
     263          68 :                 if (!read_server_final_message(state, input))
     264           0 :                     return SASL_FAILED;
     265             : 
     266             :                 /*
     267             :                  * Verify server signature, to make sure we're talking to the
     268             :                  * genuine server.
     269             :                  */
     270          68 :                 if (!verify_server_signature(state, &match, &errstr))
     271             :                 {
     272           0 :                     libpq_append_conn_error(conn, "could not verify server signature: %s", errstr);
     273           0 :                     return SASL_FAILED;
     274             :                 }
     275             : 
     276          68 :                 if (!match)
     277             :                 {
     278           0 :                     libpq_append_conn_error(conn, "incorrect server signature");
     279             :                 }
     280          68 :                 state->state = FE_SCRAM_FINISHED;
     281          68 :                 state->conn->client_finished_auth = true;
     282          68 :                 return match ? SASL_COMPLETE : SASL_FAILED;
     283             :             }
     284             : 
     285           0 :         default:
     286             :             /* shouldn't happen */
     287           0 :             libpq_append_conn_error(conn, "invalid SCRAM exchange state");
     288           0 :             break;
     289             :     }
     290             : 
     291           0 :     return SASL_FAILED;
     292             : }
     293             : 
     294             : /*
     295             :  * Read value for an attribute part of a SCRAM message.
     296             :  *
     297             :  * The buffer at **input is destructively modified, and *input is
     298             :  * advanced over the "attr=value" string and any following comma.
     299             :  *
     300             :  * On failure, append an error message to *errorMessage and return NULL.
     301             :  */
     302             : static char *
     303         308 : read_attr_value(char **input, char attr, PQExpBuffer errorMessage)
     304             : {
     305         308 :     char       *begin = *input;
     306             :     char       *end;
     307             : 
     308         308 :     if (*begin != attr)
     309             :     {
     310           0 :         libpq_append_error(errorMessage,
     311             :                            "malformed SCRAM message (attribute \"%c\" expected)",
     312             :                            attr);
     313           0 :         return NULL;
     314             :     }
     315         308 :     begin++;
     316             : 
     317         308 :     if (*begin != '=')
     318             :     {
     319           0 :         libpq_append_error(errorMessage,
     320             :                            "malformed SCRAM message (expected character \"=\" for attribute \"%c\")",
     321             :                            attr);
     322           0 :         return NULL;
     323             :     }
     324         308 :     begin++;
     325             : 
     326         308 :     end = begin;
     327        9376 :     while (*end && *end != ',')
     328        9068 :         end++;
     329             : 
     330         308 :     if (*end)
     331             :     {
     332         160 :         *end = '\0';
     333         160 :         *input = end + 1;
     334             :     }
     335             :     else
     336         148 :         *input = end;
     337             : 
     338         308 :     return begin;
     339             : }
     340             : 
     341             : /*
     342             :  * Build the first exchange message sent by the client.
     343             :  */
     344             : static char *
     345          80 : build_client_first_message(fe_scram_state *state)
     346             : {
     347          80 :     PGconn     *conn = state->conn;
     348             :     char        raw_nonce[SCRAM_RAW_NONCE_LEN + 1];
     349             :     char       *result;
     350             :     int         channel_info_len;
     351             :     int         encoded_len;
     352             :     PQExpBufferData buf;
     353             : 
     354             :     /*
     355             :      * Generate a "raw" nonce.  This is converted to ASCII-printable form by
     356             :      * base64-encoding it.
     357             :      */
     358          80 :     if (!pg_strong_random(raw_nonce, SCRAM_RAW_NONCE_LEN))
     359             :     {
     360           0 :         libpq_append_conn_error(conn, "could not generate nonce");
     361           0 :         return NULL;
     362             :     }
     363             : 
     364          80 :     encoded_len = pg_b64_enc_len(SCRAM_RAW_NONCE_LEN);
     365             :     /* don't forget the zero-terminator */
     366          80 :     state->client_nonce = malloc(encoded_len + 1);
     367          80 :     if (state->client_nonce == NULL)
     368             :     {
     369           0 :         libpq_append_conn_error(conn, "out of memory");
     370           0 :         return NULL;
     371             :     }
     372          80 :     encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN,
     373             :                                 state->client_nonce, encoded_len);
     374          80 :     if (encoded_len < 0)
     375             :     {
     376           0 :         libpq_append_conn_error(conn, "could not encode nonce");
     377           0 :         return NULL;
     378             :     }
     379          80 :     state->client_nonce[encoded_len] = '\0';
     380             : 
     381             :     /*
     382             :      * Generate message.  The username is left empty as the backend uses the
     383             :      * value provided by the startup packet.  Also, as this username is not
     384             :      * prepared with SASLprep, the message parsing would fail if it includes
     385             :      * '=' or ',' characters.
     386             :      */
     387             : 
     388          80 :     initPQExpBuffer(&buf);
     389             : 
     390             :     /*
     391             :      * First build the gs2-header with channel binding information.
     392             :      */
     393          80 :     if (strcmp(state->sasl_mechanism, SCRAM_SHA_256_PLUS_NAME) == 0)
     394             :     {
     395             :         Assert(conn->ssl_in_use);
     396          10 :         appendPQExpBufferStr(&buf, "p=tls-server-end-point");
     397             :     }
     398             : #ifdef USE_SSL
     399          70 :     else if (conn->channel_binding[0] != 'd' && /* disable */
     400          66 :              conn->ssl_in_use)
     401             :     {
     402             :         /*
     403             :          * Client supports channel binding, but thinks the server does not.
     404             :          */
     405           0 :         appendPQExpBufferChar(&buf, 'y');
     406             :     }
     407             : #endif
     408             :     else
     409             :     {
     410             :         /*
     411             :          * Client does not support channel binding, or has disabled it.
     412             :          */
     413          70 :         appendPQExpBufferChar(&buf, 'n');
     414             :     }
     415             : 
     416          80 :     if (PQExpBufferDataBroken(buf))
     417           0 :         goto oom_error;
     418             : 
     419          80 :     channel_info_len = buf.len;
     420             : 
     421          80 :     appendPQExpBuffer(&buf, ",,n=,r=%s", state->client_nonce);
     422          80 :     if (PQExpBufferDataBroken(buf))
     423           0 :         goto oom_error;
     424             : 
     425             :     /*
     426             :      * The first message content needs to be saved without channel binding
     427             :      * information.
     428             :      */
     429          80 :     state->client_first_message_bare = strdup(buf.data + channel_info_len + 2);
     430          80 :     if (!state->client_first_message_bare)
     431           0 :         goto oom_error;
     432             : 
     433          80 :     result = strdup(buf.data);
     434          80 :     if (result == NULL)
     435           0 :         goto oom_error;
     436             : 
     437          80 :     termPQExpBuffer(&buf);
     438          80 :     return result;
     439             : 
     440           0 : oom_error:
     441           0 :     termPQExpBuffer(&buf);
     442           0 :     libpq_append_conn_error(conn, "out of memory");
     443           0 :     return NULL;
     444             : }
     445             : 
     446             : /*
     447             :  * Build the final exchange message sent from the client.
     448             :  */
     449             : static char *
     450          80 : build_client_final_message(fe_scram_state *state)
     451             : {
     452             :     PQExpBufferData buf;
     453          80 :     PGconn     *conn = state->conn;
     454             :     uint8       client_proof[SCRAM_MAX_KEY_LEN];
     455             :     char       *result;
     456             :     int         encoded_len;
     457          80 :     const char *errstr = NULL;
     458             : 
     459          80 :     initPQExpBuffer(&buf);
     460             : 
     461             :     /*
     462             :      * Construct client-final-message-without-proof.  We need to remember it
     463             :      * for verifying the server proof in the final step of authentication.
     464             :      *
     465             :      * The channel binding flag handling (p/y/n) must be consistent with
     466             :      * build_client_first_message(), because the server will check that it's
     467             :      * the same flag both times.
     468             :      */
     469          80 :     if (strcmp(state->sasl_mechanism, SCRAM_SHA_256_PLUS_NAME) == 0)
     470             :     {
     471             : #ifdef USE_SSL
     472          10 :         char       *cbind_data = NULL;
     473          10 :         size_t      cbind_data_len = 0;
     474             :         size_t      cbind_header_len;
     475             :         char       *cbind_input;
     476             :         size_t      cbind_input_len;
     477             :         int         encoded_cbind_len;
     478             : 
     479             :         /* Fetch hash data of server's SSL certificate */
     480             :         cbind_data =
     481          10 :             pgtls_get_peer_certificate_hash(state->conn,
     482             :                                             &cbind_data_len);
     483          10 :         if (cbind_data == NULL)
     484             :         {
     485             :             /* error message is already set on error */
     486           0 :             termPQExpBuffer(&buf);
     487           0 :             return NULL;
     488             :         }
     489             : 
     490          10 :         appendPQExpBufferStr(&buf, "c=");
     491             : 
     492             :         /* p=type,, */
     493          10 :         cbind_header_len = strlen("p=tls-server-end-point,,");
     494          10 :         cbind_input_len = cbind_header_len + cbind_data_len;
     495          10 :         cbind_input = malloc(cbind_input_len);
     496          10 :         if (!cbind_input)
     497             :         {
     498           0 :             free(cbind_data);
     499           0 :             goto oom_error;
     500             :         }
     501          10 :         memcpy(cbind_input, "p=tls-server-end-point,,", cbind_header_len);
     502          10 :         memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len);
     503             : 
     504          10 :         encoded_cbind_len = pg_b64_enc_len(cbind_input_len);
     505          10 :         if (!enlargePQExpBuffer(&buf, encoded_cbind_len))
     506             :         {
     507           0 :             free(cbind_data);
     508           0 :             free(cbind_input);
     509           0 :             goto oom_error;
     510             :         }
     511          10 :         encoded_cbind_len = pg_b64_encode(cbind_input, cbind_input_len,
     512          10 :                                           buf.data + buf.len,
     513             :                                           encoded_cbind_len);
     514          10 :         if (encoded_cbind_len < 0)
     515             :         {
     516           0 :             free(cbind_data);
     517           0 :             free(cbind_input);
     518           0 :             termPQExpBuffer(&buf);
     519           0 :             appendPQExpBufferStr(&conn->errorMessage,
     520             :                                  "could not encode cbind data for channel binding\n");
     521           0 :             return NULL;
     522             :         }
     523          10 :         buf.len += encoded_cbind_len;
     524          10 :         buf.data[buf.len] = '\0';
     525             : 
     526          10 :         free(cbind_data);
     527          10 :         free(cbind_input);
     528             : #else
     529             :         /*
     530             :          * Chose channel binding, but the SSL library doesn't support it.
     531             :          * Shouldn't happen.
     532             :          */
     533             :         termPQExpBuffer(&buf);
     534             :         appendPQExpBufferStr(&conn->errorMessage,
     535             :                              "channel binding not supported by this build\n");
     536             :         return NULL;
     537             : #endif                          /* USE_SSL */
     538             :     }
     539             : #ifdef USE_SSL
     540          70 :     else if (conn->channel_binding[0] != 'd' && /* disable */
     541          66 :              conn->ssl_in_use)
     542           0 :         appendPQExpBufferStr(&buf, "c=eSws"); /* base64 of "y,," */
     543             : #endif
     544             :     else
     545          70 :         appendPQExpBufferStr(&buf, "c=biws"); /* base64 of "n,," */
     546             : 
     547          80 :     if (PQExpBufferDataBroken(buf))
     548           0 :         goto oom_error;
     549             : 
     550          80 :     appendPQExpBuffer(&buf, ",r=%s", state->nonce);
     551          80 :     if (PQExpBufferDataBroken(buf))
     552           0 :         goto oom_error;
     553             : 
     554          80 :     state->client_final_message_without_proof = strdup(buf.data);
     555          80 :     if (state->client_final_message_without_proof == NULL)
     556           0 :         goto oom_error;
     557             : 
     558             :     /* Append proof to it, to form client-final-message. */
     559          80 :     if (!calculate_client_proof(state,
     560          80 :                                 state->client_final_message_without_proof,
     561             :                                 client_proof, &errstr))
     562             :     {
     563           0 :         termPQExpBuffer(&buf);
     564           0 :         libpq_append_conn_error(conn, "could not calculate client proof: %s", errstr);
     565           0 :         return NULL;
     566             :     }
     567             : 
     568          80 :     appendPQExpBufferStr(&buf, ",p=");
     569          80 :     encoded_len = pg_b64_enc_len(state->key_length);
     570          80 :     if (!enlargePQExpBuffer(&buf, encoded_len))
     571           0 :         goto oom_error;
     572          80 :     encoded_len = pg_b64_encode((char *) client_proof,
     573             :                                 state->key_length,
     574          80 :                                 buf.data + buf.len,
     575             :                                 encoded_len);
     576          80 :     if (encoded_len < 0)
     577             :     {
     578           0 :         termPQExpBuffer(&buf);
     579           0 :         libpq_append_conn_error(conn, "could not encode client proof");
     580           0 :         return NULL;
     581             :     }
     582          80 :     buf.len += encoded_len;
     583          80 :     buf.data[buf.len] = '\0';
     584             : 
     585          80 :     result = strdup(buf.data);
     586          80 :     if (result == NULL)
     587           0 :         goto oom_error;
     588             : 
     589          80 :     termPQExpBuffer(&buf);
     590          80 :     return result;
     591             : 
     592           0 : oom_error:
     593           0 :     termPQExpBuffer(&buf);
     594           0 :     libpq_append_conn_error(conn, "out of memory");
     595           0 :     return NULL;
     596             : }
     597             : 
     598             : /*
     599             :  * Read the first exchange message coming from the server.
     600             :  */
     601             : static bool
     602          80 : read_server_first_message(fe_scram_state *state, char *input)
     603             : {
     604          80 :     PGconn     *conn = state->conn;
     605             :     char       *iterations_str;
     606             :     char       *endptr;
     607             :     char       *encoded_salt;
     608             :     char       *nonce;
     609             :     int         decoded_salt_len;
     610             : 
     611          80 :     state->server_first_message = strdup(input);
     612          80 :     if (state->server_first_message == NULL)
     613             :     {
     614           0 :         libpq_append_conn_error(conn, "out of memory");
     615           0 :         return false;
     616             :     }
     617             : 
     618             :     /* parse the message */
     619          80 :     nonce = read_attr_value(&input, 'r',
     620             :                             &conn->errorMessage);
     621          80 :     if (nonce == NULL)
     622             :     {
     623             :         /* read_attr_value() has appended an error string */
     624           0 :         return false;
     625             :     }
     626             : 
     627             :     /* Verify immediately that the server used our part of the nonce */
     628          80 :     if (strlen(nonce) < strlen(state->client_nonce) ||
     629          80 :         memcmp(nonce, state->client_nonce, strlen(state->client_nonce)) != 0)
     630             :     {
     631           0 :         libpq_append_conn_error(conn, "invalid SCRAM response (nonce mismatch)");
     632           0 :         return false;
     633             :     }
     634             : 
     635          80 :     state->nonce = strdup(nonce);
     636          80 :     if (state->nonce == NULL)
     637             :     {
     638           0 :         libpq_append_conn_error(conn, "out of memory");
     639           0 :         return false;
     640             :     }
     641             : 
     642          80 :     encoded_salt = read_attr_value(&input, 's', &conn->errorMessage);
     643          80 :     if (encoded_salt == NULL)
     644             :     {
     645             :         /* read_attr_value() has appended an error string */
     646           0 :         return false;
     647             :     }
     648          80 :     decoded_salt_len = pg_b64_dec_len(strlen(encoded_salt));
     649          80 :     state->salt = malloc(decoded_salt_len);
     650          80 :     if (state->salt == NULL)
     651             :     {
     652           0 :         libpq_append_conn_error(conn, "out of memory");
     653           0 :         return false;
     654             :     }
     655         160 :     state->saltlen = pg_b64_decode(encoded_salt,
     656          80 :                                    strlen(encoded_salt),
     657             :                                    state->salt,
     658             :                                    decoded_salt_len);
     659          80 :     if (state->saltlen < 0)
     660             :     {
     661           0 :         libpq_append_conn_error(conn, "malformed SCRAM message (invalid salt)");
     662           0 :         return false;
     663             :     }
     664             : 
     665          80 :     iterations_str = read_attr_value(&input, 'i', &conn->errorMessage);
     666          80 :     if (iterations_str == NULL)
     667             :     {
     668             :         /* read_attr_value() has appended an error string */
     669           0 :         return false;
     670             :     }
     671          80 :     state->iterations = strtol(iterations_str, &endptr, 10);
     672          80 :     if (*endptr != '\0' || state->iterations < 1)
     673             :     {
     674           0 :         libpq_append_conn_error(conn, "malformed SCRAM message (invalid iteration count)");
     675           0 :         return false;
     676             :     }
     677             : 
     678          80 :     if (*input != '\0')
     679           0 :         libpq_append_conn_error(conn, "malformed SCRAM message (garbage at end of server-first-message)");
     680             : 
     681          80 :     return true;
     682             : }
     683             : 
     684             : /*
     685             :  * Read the final exchange message coming from the server.
     686             :  */
     687             : static bool
     688          68 : read_server_final_message(fe_scram_state *state, char *input)
     689             : {
     690          68 :     PGconn     *conn = state->conn;
     691             :     char       *encoded_server_signature;
     692             :     char       *decoded_server_signature;
     693             :     int         server_signature_len;
     694             : 
     695          68 :     state->server_final_message = strdup(input);
     696          68 :     if (!state->server_final_message)
     697             :     {
     698           0 :         libpq_append_conn_error(conn, "out of memory");
     699           0 :         return false;
     700             :     }
     701             : 
     702             :     /* Check for error result. */
     703          68 :     if (*input == 'e')
     704             :     {
     705           0 :         char       *errmsg = read_attr_value(&input, 'e',
     706             :                                              &conn->errorMessage);
     707             : 
     708           0 :         if (errmsg == NULL)
     709             :         {
     710             :             /* read_attr_value() has appended an error message */
     711           0 :             return false;
     712             :         }
     713           0 :         libpq_append_conn_error(conn, "error received from server in SCRAM exchange: %s",
     714             :                                 errmsg);
     715           0 :         return false;
     716             :     }
     717             : 
     718             :     /* Parse the message. */
     719          68 :     encoded_server_signature = read_attr_value(&input, 'v',
     720             :                                                &conn->errorMessage);
     721          68 :     if (encoded_server_signature == NULL)
     722             :     {
     723             :         /* read_attr_value() has appended an error message */
     724           0 :         return false;
     725             :     }
     726             : 
     727          68 :     if (*input != '\0')
     728           0 :         libpq_append_conn_error(conn, "malformed SCRAM message (garbage at end of server-final-message)");
     729             : 
     730          68 :     server_signature_len = pg_b64_dec_len(strlen(encoded_server_signature));
     731          68 :     decoded_server_signature = malloc(server_signature_len);
     732          68 :     if (!decoded_server_signature)
     733             :     {
     734           0 :         libpq_append_conn_error(conn, "out of memory");
     735           0 :         return false;
     736             :     }
     737             : 
     738          68 :     server_signature_len = pg_b64_decode(encoded_server_signature,
     739          68 :                                          strlen(encoded_server_signature),
     740             :                                          decoded_server_signature,
     741             :                                          server_signature_len);
     742          68 :     if (server_signature_len != state->key_length)
     743             :     {
     744           0 :         free(decoded_server_signature);
     745           0 :         libpq_append_conn_error(conn, "malformed SCRAM message (invalid server signature)");
     746           0 :         return false;
     747             :     }
     748          68 :     memcpy(state->ServerSignature, decoded_server_signature,
     749          68 :            state->key_length);
     750          68 :     free(decoded_server_signature);
     751             : 
     752          68 :     return true;
     753             : }
     754             : 
     755             : /*
     756             :  * Calculate the client proof, part of the final exchange message sent
     757             :  * by the client.  Returns true on success, false on failure with *errstr
     758             :  * pointing to a message about the error details.
     759             :  */
     760             : static bool
     761          80 : calculate_client_proof(fe_scram_state *state,
     762             :                        const char *client_final_message_without_proof,
     763             :                        uint8 *result, const char **errstr)
     764             : {
     765             :     uint8       StoredKey[SCRAM_MAX_KEY_LEN];
     766             :     uint8       ClientKey[SCRAM_MAX_KEY_LEN];
     767             :     uint8       ClientSignature[SCRAM_MAX_KEY_LEN];
     768             :     int         i;
     769             :     pg_hmac_ctx *ctx;
     770             : 
     771          80 :     ctx = pg_hmac_create(state->hash_type);
     772          80 :     if (ctx == NULL)
     773             :     {
     774           0 :         *errstr = pg_hmac_error(NULL);  /* returns OOM */
     775           0 :         return false;
     776             :     }
     777             : 
     778             :     /*
     779             :      * Calculate SaltedPassword, and store it in 'state' so that we can reuse
     780             :      * it later in verify_server_signature.
     781             :      */
     782          80 :     if (scram_SaltedPassword(state->password, state->hash_type,
     783          80 :                              state->key_length, state->salt, state->saltlen,
     784          80 :                              state->iterations, state->SaltedPassword,
     785          80 :                              errstr) < 0 ||
     786          80 :         scram_ClientKey(state->SaltedPassword, state->hash_type,
     787          80 :                         state->key_length, ClientKey, errstr) < 0 ||
     788          80 :         scram_H(ClientKey, state->hash_type, state->key_length,
     789             :                 StoredKey, errstr) < 0)
     790             :     {
     791             :         /* errstr is already filled here */
     792           0 :         pg_hmac_free(ctx);
     793           0 :         return false;
     794             :     }
     795             : 
     796         160 :     if (pg_hmac_init(ctx, StoredKey, state->key_length) < 0 ||
     797          80 :         pg_hmac_update(ctx,
     798          80 :                        (uint8 *) state->client_first_message_bare,
     799         160 :                        strlen(state->client_first_message_bare)) < 0 ||
     800         160 :         pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
     801          80 :         pg_hmac_update(ctx,
     802          80 :                        (uint8 *) state->server_first_message,
     803         160 :                        strlen(state->server_first_message)) < 0 ||
     804         160 :         pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
     805          80 :         pg_hmac_update(ctx,
     806             :                        (uint8 *) client_final_message_without_proof,
     807          80 :                        strlen(client_final_message_without_proof)) < 0 ||
     808          80 :         pg_hmac_final(ctx, ClientSignature, state->key_length) < 0)
     809             :     {
     810           0 :         *errstr = pg_hmac_error(ctx);
     811           0 :         pg_hmac_free(ctx);
     812           0 :         return false;
     813             :     }
     814             : 
     815        2640 :     for (i = 0; i < state->key_length; i++)
     816        2560 :         result[i] = ClientKey[i] ^ ClientSignature[i];
     817             : 
     818          80 :     pg_hmac_free(ctx);
     819          80 :     return true;
     820             : }
     821             : 
     822             : /*
     823             :  * Validate the server signature, received as part of the final exchange
     824             :  * message received from the server.  *match tracks if the server signature
     825             :  * matched or not. Returns true if the server signature got verified, and
     826             :  * false for a processing error with *errstr pointing to a message about the
     827             :  * error details.
     828             :  */
     829             : static bool
     830          68 : verify_server_signature(fe_scram_state *state, bool *match,
     831             :                         const char **errstr)
     832             : {
     833             :     uint8       expected_ServerSignature[SCRAM_MAX_KEY_LEN];
     834             :     uint8       ServerKey[SCRAM_MAX_KEY_LEN];
     835             :     pg_hmac_ctx *ctx;
     836             : 
     837          68 :     ctx = pg_hmac_create(state->hash_type);
     838          68 :     if (ctx == NULL)
     839             :     {
     840           0 :         *errstr = pg_hmac_error(NULL);  /* returns OOM */
     841           0 :         return false;
     842             :     }
     843             : 
     844          68 :     if (scram_ServerKey(state->SaltedPassword, state->hash_type,
     845             :                         state->key_length, ServerKey, errstr) < 0)
     846             :     {
     847             :         /* errstr is filled already */
     848           0 :         pg_hmac_free(ctx);
     849           0 :         return false;
     850             :     }
     851             : 
     852             :     /* calculate ServerSignature */
     853         136 :     if (pg_hmac_init(ctx, ServerKey, state->key_length) < 0 ||
     854          68 :         pg_hmac_update(ctx,
     855          68 :                        (uint8 *) state->client_first_message_bare,
     856         136 :                        strlen(state->client_first_message_bare)) < 0 ||
     857         136 :         pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
     858          68 :         pg_hmac_update(ctx,
     859          68 :                        (uint8 *) state->server_first_message,
     860         136 :                        strlen(state->server_first_message)) < 0 ||
     861         136 :         pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
     862          68 :         pg_hmac_update(ctx,
     863          68 :                        (uint8 *) state->client_final_message_without_proof,
     864         136 :                        strlen(state->client_final_message_without_proof)) < 0 ||
     865          68 :         pg_hmac_final(ctx, expected_ServerSignature,
     866          68 :                       state->key_length) < 0)
     867             :     {
     868           0 :         *errstr = pg_hmac_error(ctx);
     869           0 :         pg_hmac_free(ctx);
     870           0 :         return false;
     871             :     }
     872             : 
     873          68 :     pg_hmac_free(ctx);
     874             : 
     875             :     /* signature processed, so now check after it */
     876          68 :     if (memcmp(expected_ServerSignature, state->ServerSignature,
     877          68 :                state->key_length) != 0)
     878           0 :         *match = false;
     879             :     else
     880          68 :         *match = true;
     881             : 
     882          68 :     return true;
     883             : }
     884             : 
     885             : /*
     886             :  * Build a new SCRAM secret.
     887             :  *
     888             :  * On error, returns NULL and sets *errstr to point to a message about the
     889             :  * error details.
     890             :  */
     891             : char *
     892           2 : pg_fe_scram_build_secret(const char *password, int iterations, const char **errstr)
     893             : {
     894             :     char       *prep_password;
     895             :     pg_saslprep_rc rc;
     896             :     char        saltbuf[SCRAM_DEFAULT_SALT_LEN];
     897             :     char       *result;
     898             : 
     899             :     /*
     900             :      * Normalize the password with SASLprep.  If that doesn't work, because
     901             :      * the password isn't valid UTF-8 or contains prohibited characters, just
     902             :      * proceed with the original password.  (See comments at the top of
     903             :      * auth-scram.c.)
     904             :      */
     905           2 :     rc = pg_saslprep(password, &prep_password);
     906           2 :     if (rc == SASLPREP_OOM)
     907             :     {
     908           0 :         *errstr = libpq_gettext("out of memory");
     909           0 :         return NULL;
     910             :     }
     911           2 :     if (rc == SASLPREP_SUCCESS)
     912           2 :         password = (const char *) prep_password;
     913             : 
     914             :     /* Generate a random salt */
     915           2 :     if (!pg_strong_random(saltbuf, SCRAM_DEFAULT_SALT_LEN))
     916             :     {
     917           0 :         *errstr = libpq_gettext("could not generate random salt");
     918           0 :         free(prep_password);
     919           0 :         return NULL;
     920             :     }
     921             : 
     922           2 :     result = scram_build_secret(PG_SHA256, SCRAM_SHA_256_KEY_LEN, saltbuf,
     923             :                                 SCRAM_DEFAULT_SALT_LEN,
     924             :                                 iterations, password,
     925             :                                 errstr);
     926             : 
     927           2 :     free(prep_password);
     928             : 
     929           2 :     return result;
     930             : }

Generated by: LCOV version 1.14