LCOV - code coverage report
Current view: top level - src/interfaces/libpq - fe-auth-scram.c (source / functions) Hit Total Coverage
Test: PostgreSQL 15devel Lines: 247 400 61.8 %
Date: 2021-12-09 03:08:47 Functions: 11 12 91.7 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : /*-------------------------------------------------------------------------
       2             :  *
       3             :  * fe-auth-scram.c
       4             :  *     The front-end (client) implementation of SCRAM authentication.
       5             :  *
       6             :  * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
       7             :  * Portions Copyright (c) 1994, Regents of the University of California
       8             :  *
       9             :  * IDENTIFICATION
      10             :  *    src/interfaces/libpq/fe-auth-scram.c
      11             :  *
      12             :  *-------------------------------------------------------------------------
      13             :  */
      14             : 
      15             : #include "postgres_fe.h"
      16             : 
      17             : #include "common/base64.h"
      18             : #include "common/hmac.h"
      19             : #include "common/saslprep.h"
      20             : #include "common/scram-common.h"
      21             : #include "fe-auth.h"
      22             : 
      23             : 
/*
 * The exported SCRAM callback mechanism.
 *
 * The four callbacks below are collected into pg_scram_mech, in the order
 * init, exchange, channel_bound, free.  NOTE(review): pg_fe_sasl_mech is
 * declared outside this file; confirm its member order matches this
 * positional initializer.
 */
static void *scram_init(PGconn *conn, const char *password,
                        const char *sasl_mechanism);
static void scram_exchange(void *opaq, char *input, int inputlen,
                           char **output, int *outputlen,
                           bool *done, bool *success);
static bool scram_channel_bound(void *opaq);
static void scram_free(void *opaq);

/* SASL mechanism descriptor exposing the SCRAM callbacks declared above. */
const pg_fe_sasl_mech pg_scram_mech = {
    scram_init,
    scram_exchange,
    scram_channel_bound,
    scram_free
};
      39             : 
/*
 * Status of exchange messages used for SCRAM authentication via the
 * SASL protocol.
 *
 * scram_exchange() moves INIT -> NONCE_SENT -> PROOF_SENT -> FINISHED,
 * advancing by one step on each successful call.
 */
typedef enum
{
    FE_SCRAM_INIT,              /* nothing sent to the server yet */
    FE_SCRAM_NONCE_SENT,        /* client-first-message (client nonce) sent */
    FE_SCRAM_PROOF_SENT,        /* client-final-message (with proof) sent */
    FE_SCRAM_FINISHED           /* server-final-message has been processed */
} fe_scram_state_enum;

/*
 * State of one SCRAM exchange.  Allocated by scram_init() and released,
 * together with every heap buffer it points to, by scram_free().
 */
typedef struct
{
    fe_scram_state_enum state;

    /* These are supplied by the user */
    PGconn     *conn;
    char       *password;       /* SASLprep'd copy, or a verbatim copy if
                                 * SASLprep rejected the password */
    char       *sasl_mechanism; /* heap copy of the chosen mechanism name */

    /* We construct these */
    uint8       SaltedPassword[SCRAM_KEY_LEN];
    char       *client_nonce;   /* base64-encoded, NUL-terminated */
    char       *client_first_message_bare;
    char       *client_final_message_without_proof;

    /* These come from the server-first message */
    char       *server_first_message;
    char       *salt;
    int         saltlen;
    int         iterations;
    char       *nonce;

    /* These come from the server-final message */
    char       *server_final_message;
    char        ServerSignature[SCRAM_KEY_LEN];
} fe_scram_state;
      78             : 
/*
 * Helpers for the individual SCRAM protocol steps.  The read_* functions
 * parse messages received from the server, the build_* functions return the
 * client's messages, and the last two perform the cryptographic checks.
 * Failure is reported by returning false (or NULL for the build_* helpers).
 */
static bool read_server_first_message(fe_scram_state *state, char *input);
static bool read_server_final_message(fe_scram_state *state, char *input);
static char *build_client_first_message(fe_scram_state *state);
static char *build_client_final_message(fe_scram_state *state);
/* Returns false on internal error; *match reports whether the signature matched. */
static bool verify_server_signature(fe_scram_state *state, bool *match);
static bool calculate_client_proof(fe_scram_state *state,
                                   const char *client_final_message_without_proof,
                                   uint8 *result);
      87             : 
      88             : /*
      89             :  * Initialize SCRAM exchange status.
      90             :  */
      91             : static void *
      92          40 : scram_init(PGconn *conn,
      93             :            const char *password,
      94             :            const char *sasl_mechanism)
      95             : {
      96             :     fe_scram_state *state;
      97             :     char       *prep_password;
      98             :     pg_saslprep_rc rc;
      99             : 
     100             :     Assert(sasl_mechanism != NULL);
     101             : 
     102          40 :     state = (fe_scram_state *) malloc(sizeof(fe_scram_state));
     103          40 :     if (!state)
     104           0 :         return NULL;
     105          40 :     memset(state, 0, sizeof(fe_scram_state));
     106          40 :     state->conn = conn;
     107          40 :     state->state = FE_SCRAM_INIT;
     108          40 :     state->sasl_mechanism = strdup(sasl_mechanism);
     109             : 
     110          40 :     if (!state->sasl_mechanism)
     111             :     {
     112           0 :         free(state);
     113           0 :         return NULL;
     114             :     }
     115             : 
     116             :     /* Normalize the password with SASLprep, if possible */
     117          40 :     rc = pg_saslprep(password, &prep_password);
     118          40 :     if (rc == SASLPREP_OOM)
     119             :     {
     120           0 :         free(state->sasl_mechanism);
     121           0 :         free(state);
     122           0 :         return NULL;
     123             :     }
     124          40 :     if (rc != SASLPREP_SUCCESS)
     125             :     {
     126           4 :         prep_password = strdup(password);
     127           4 :         if (!prep_password)
     128             :         {
     129           0 :             free(state->sasl_mechanism);
     130           0 :             free(state);
     131           0 :             return NULL;
     132             :         }
     133             :     }
     134          40 :     state->password = prep_password;
     135             : 
     136          40 :     return state;
     137             : }
     138             : 
     139             : /*
     140             :  * Return true if channel binding was employed and the SCRAM exchange
     141             :  * completed. This should be used after a successful exchange to determine
     142             :  * whether the server authenticated itself to the client.
     143             :  *
     144             :  * Note that the caller must also ensure that the exchange was actually
     145             :  * successful.
     146             :  */
     147             : static bool
     148           2 : scram_channel_bound(void *opaq)
     149             : {
     150           2 :     fe_scram_state *state = (fe_scram_state *) opaq;
     151             : 
     152             :     /* no SCRAM exchange done */
     153           2 :     if (state == NULL)
     154           0 :         return false;
     155             : 
     156             :     /* SCRAM exchange not completed */
     157           2 :     if (state->state != FE_SCRAM_FINISHED)
     158           0 :         return false;
     159             : 
     160             :     /* channel binding mechanism not used */
     161           2 :     if (strcmp(state->sasl_mechanism, SCRAM_SHA_256_PLUS_NAME) != 0)
     162           0 :         return false;
     163             : 
     164             :     /* all clear! */
     165           2 :     return true;
     166             : }
     167             : 
     168             : /*
     169             :  * Free SCRAM exchange status
     170             :  */
     171             : static void
     172          40 : scram_free(void *opaq)
     173             : {
     174          40 :     fe_scram_state *state = (fe_scram_state *) opaq;
     175             : 
     176          40 :     if (state->password)
     177          40 :         free(state->password);
     178          40 :     if (state->sasl_mechanism)
     179          40 :         free(state->sasl_mechanism);
     180             : 
     181             :     /* client messages */
     182          40 :     if (state->client_nonce)
     183          40 :         free(state->client_nonce);
     184          40 :     if (state->client_first_message_bare)
     185          40 :         free(state->client_first_message_bare);
     186          40 :     if (state->client_final_message_without_proof)
     187          40 :         free(state->client_final_message_without_proof);
     188             : 
     189             :     /* first message from server */
     190          40 :     if (state->server_first_message)
     191          40 :         free(state->server_first_message);
     192          40 :     if (state->salt)
     193          40 :         free(state->salt);
     194          40 :     if (state->nonce)
     195          40 :         free(state->nonce);
     196             : 
     197             :     /* final message from server */
     198          40 :     if (state->server_final_message)
     199          28 :         free(state->server_final_message);
     200             : 
     201          40 :     free(state);
     202          40 : }
     203             : 
     204             : /*
     205             :  * Exchange a SCRAM message with backend.
     206             :  */
     207             : static void
     208         108 : scram_exchange(void *opaq, char *input, int inputlen,
     209             :                char **output, int *outputlen,
     210             :                bool *done, bool *success)
     211             : {
     212         108 :     fe_scram_state *state = (fe_scram_state *) opaq;
     213         108 :     PGconn     *conn = state->conn;
     214             : 
     215         108 :     *done = false;
     216         108 :     *success = false;
     217         108 :     *output = NULL;
     218         108 :     *outputlen = 0;
     219             : 
     220             :     /*
     221             :      * Check that the input length agrees with the string length of the input.
     222             :      * We can ignore inputlen after this.
     223             :      */
     224         108 :     if (state->state != FE_SCRAM_INIT)
     225             :     {
     226          68 :         if (inputlen == 0)
     227             :         {
     228           0 :             appendPQExpBufferStr(&conn->errorMessage,
     229           0 :                                  libpq_gettext("malformed SCRAM message (empty message)\n"));
     230           0 :             goto error;
     231             :         }
     232          68 :         if (inputlen != strlen(input))
     233             :         {
     234           0 :             appendPQExpBufferStr(&conn->errorMessage,
     235           0 :                                  libpq_gettext("malformed SCRAM message (length mismatch)\n"));
     236           0 :             goto error;
     237             :         }
     238             :     }
     239             : 
     240         108 :     switch (state->state)
     241             :     {
     242          40 :         case FE_SCRAM_INIT:
     243             :             /* Begin the SCRAM handshake, by sending client nonce */
     244          40 :             *output = build_client_first_message(state);
     245          40 :             if (*output == NULL)
     246           0 :                 goto error;
     247             : 
     248          40 :             *outputlen = strlen(*output);
     249          40 :             *done = false;
     250          40 :             state->state = FE_SCRAM_NONCE_SENT;
     251          40 :             break;
     252             : 
     253          40 :         case FE_SCRAM_NONCE_SENT:
     254             :             /* Receive salt and server nonce, send response. */
     255          40 :             if (!read_server_first_message(state, input))
     256           0 :                 goto error;
     257             : 
     258          40 :             *output = build_client_final_message(state);
     259          40 :             if (*output == NULL)
     260           0 :                 goto error;
     261             : 
     262          40 :             *outputlen = strlen(*output);
     263          40 :             *done = false;
     264          40 :             state->state = FE_SCRAM_PROOF_SENT;
     265          40 :             break;
     266             : 
     267          28 :         case FE_SCRAM_PROOF_SENT:
     268             :             /* Receive server signature */
     269          28 :             if (!read_server_final_message(state, input))
     270           0 :                 goto error;
     271             : 
     272             :             /*
     273             :              * Verify server signature, to make sure we're talking to the
     274             :              * genuine server.
     275             :              */
     276          28 :             if (!verify_server_signature(state, success))
     277             :             {
     278           0 :                 appendPQExpBufferStr(&conn->errorMessage,
     279           0 :                                      libpq_gettext("could not verify server signature\n"));
     280           0 :                 goto error;
     281             :             }
     282             : 
     283          28 :             if (!*success)
     284             :             {
     285           0 :                 appendPQExpBufferStr(&conn->errorMessage,
     286           0 :                                      libpq_gettext("incorrect server signature\n"));
     287             :             }
     288          28 :             *done = true;
     289          28 :             state->state = FE_SCRAM_FINISHED;
     290          28 :             break;
     291             : 
     292           0 :         default:
     293             :             /* shouldn't happen */
     294           0 :             appendPQExpBufferStr(&conn->errorMessage,
     295           0 :                                  libpq_gettext("invalid SCRAM exchange state\n"));
     296           0 :             goto error;
     297             :     }
     298         108 :     return;
     299             : 
     300           0 : error:
     301           0 :     *done = true;
     302           0 :     *success = false;
     303             : }
     304             : 
     305             : /*
     306             :  * Read value for an attribute part of a SCRAM message.
     307             :  *
     308             :  * The buffer at **input is destructively modified, and *input is
     309             :  * advanced over the "attr=value" string and any following comma.
     310             :  *
     311             :  * On failure, append an error message to *errorMessage and return NULL.
     312             :  */
     313             : static char *
     314         148 : read_attr_value(char **input, char attr, PQExpBuffer errorMessage)
     315             : {
     316         148 :     char       *begin = *input;
     317             :     char       *end;
     318             : 
     319         148 :     if (*begin != attr)
     320             :     {
     321           0 :         appendPQExpBuffer(errorMessage,
     322           0 :                           libpq_gettext("malformed SCRAM message (attribute \"%c\" expected)\n"),
     323             :                           attr);
     324           0 :         return NULL;
     325             :     }
     326         148 :     begin++;
     327             : 
     328         148 :     if (*begin != '=')
     329             :     {
     330           0 :         appendPQExpBuffer(errorMessage,
     331           0 :                           libpq_gettext("malformed SCRAM message (expected character \"=\" for attribute \"%c\")\n"),
     332             :                           attr);
     333           0 :         return NULL;
     334             :     }
     335         148 :     begin++;
     336             : 
     337         148 :     end = begin;
     338        4420 :     while (*end && *end != ',')
     339        4272 :         end++;
     340             : 
     341         148 :     if (*end)
     342             :     {
     343          80 :         *end = '\0';
     344          80 :         *input = end + 1;
     345             :     }
     346             :     else
     347          68 :         *input = end;
     348             : 
     349         148 :     return begin;
     350             : }
     351             : 
     352             : /*
     353             :  * Build the first exchange message sent by the client.
     354             :  */
     355             : static char *
     356          40 : build_client_first_message(fe_scram_state *state)
     357             : {
     358          40 :     PGconn     *conn = state->conn;
     359             :     char        raw_nonce[SCRAM_RAW_NONCE_LEN + 1];
     360             :     char       *result;
     361             :     int         channel_info_len;
     362             :     int         encoded_len;
     363             :     PQExpBufferData buf;
     364             : 
     365             :     /*
     366             :      * Generate a "raw" nonce.  This is converted to ASCII-printable form by
     367             :      * base64-encoding it.
     368             :      */
     369          40 :     if (!pg_strong_random(raw_nonce, SCRAM_RAW_NONCE_LEN))
     370             :     {
     371           0 :         appendPQExpBufferStr(&conn->errorMessage,
     372           0 :                              libpq_gettext("could not generate nonce\n"));
     373           0 :         return NULL;
     374             :     }
     375             : 
     376          40 :     encoded_len = pg_b64_enc_len(SCRAM_RAW_NONCE_LEN);
     377             :     /* don't forget the zero-terminator */
     378          40 :     state->client_nonce = malloc(encoded_len + 1);
     379          40 :     if (state->client_nonce == NULL)
     380             :     {
     381           0 :         appendPQExpBufferStr(&conn->errorMessage,
     382           0 :                              libpq_gettext("out of memory\n"));
     383           0 :         return NULL;
     384             :     }
     385          40 :     encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN,
     386             :                                 state->client_nonce, encoded_len);
     387          40 :     if (encoded_len < 0)
     388             :     {
     389           0 :         appendPQExpBufferStr(&conn->errorMessage,
     390           0 :                              libpq_gettext("could not encode nonce\n"));
     391           0 :         return NULL;
     392             :     }
     393          40 :     state->client_nonce[encoded_len] = '\0';
     394             : 
     395             :     /*
     396             :      * Generate message.  The username is left empty as the backend uses the
     397             :      * value provided by the startup packet.  Also, as this username is not
     398             :      * prepared with SASLprep, the message parsing would fail if it includes
     399             :      * '=' or ',' characters.
     400             :      */
     401             : 
     402          40 :     initPQExpBuffer(&buf);
     403             : 
     404             :     /*
     405             :      * First build the gs2-header with channel binding information.
     406             :      */
     407          40 :     if (strcmp(state->sasl_mechanism, SCRAM_SHA_256_PLUS_NAME) == 0)
     408             :     {
     409             :         Assert(conn->ssl_in_use);
     410           6 :         appendPQExpBufferStr(&buf, "p=tls-server-end-point");
     411             :     }
     412             : #ifdef HAVE_PGTLS_GET_PEER_CERTIFICATE_HASH
     413          34 :     else if (conn->channel_binding[0] != 'd' && /* disable */
     414          32 :              conn->ssl_in_use)
     415             :     {
     416             :         /*
     417             :          * Client supports channel binding, but thinks the server does not.
     418             :          */
     419           0 :         appendPQExpBufferChar(&buf, 'y');
     420             :     }
     421             : #endif
     422             :     else
     423             :     {
     424             :         /*
     425             :          * Client does not support channel binding, or has disabled it.
     426             :          */
     427          34 :         appendPQExpBufferChar(&buf, 'n');
     428             :     }
     429             : 
     430          40 :     if (PQExpBufferDataBroken(buf))
     431           0 :         goto oom_error;
     432             : 
     433          40 :     channel_info_len = buf.len;
     434             : 
     435          40 :     appendPQExpBuffer(&buf, ",,n=,r=%s", state->client_nonce);
     436          40 :     if (PQExpBufferDataBroken(buf))
     437           0 :         goto oom_error;
     438             : 
     439             :     /*
     440             :      * The first message content needs to be saved without channel binding
     441             :      * information.
     442             :      */
     443          40 :     state->client_first_message_bare = strdup(buf.data + channel_info_len + 2);
     444          40 :     if (!state->client_first_message_bare)
     445           0 :         goto oom_error;
     446             : 
     447          40 :     result = strdup(buf.data);
     448          40 :     if (result == NULL)
     449           0 :         goto oom_error;
     450             : 
     451          40 :     termPQExpBuffer(&buf);
     452          40 :     return result;
     453             : 
     454           0 : oom_error:
     455           0 :     termPQExpBuffer(&buf);
     456           0 :     appendPQExpBufferStr(&conn->errorMessage,
     457           0 :                          libpq_gettext("out of memory\n"));
     458           0 :     return NULL;
     459             : }
     460             : 
     461             : /*
     462             :  * Build the final exchange message sent from the client.
     463             :  */
     464             : static char *
     465          40 : build_client_final_message(fe_scram_state *state)
     466             : {
     467             :     PQExpBufferData buf;
     468          40 :     PGconn     *conn = state->conn;
     469             :     uint8       client_proof[SCRAM_KEY_LEN];
     470             :     char       *result;
     471             :     int         encoded_len;
     472             : 
     473          40 :     initPQExpBuffer(&buf);
     474             : 
     475             :     /*
     476             :      * Construct client-final-message-without-proof.  We need to remember it
     477             :      * for verifying the server proof in the final step of authentication.
     478             :      *
     479             :      * The channel binding flag handling (p/y/n) must be consistent with
     480             :      * build_client_first_message(), because the server will check that it's
     481             :      * the same flag both times.
     482             :      */
     483          40 :     if (strcmp(state->sasl_mechanism, SCRAM_SHA_256_PLUS_NAME) == 0)
     484             :     {
     485             : #ifdef HAVE_PGTLS_GET_PEER_CERTIFICATE_HASH
     486           6 :         char       *cbind_data = NULL;
     487           6 :         size_t      cbind_data_len = 0;
     488             :         size_t      cbind_header_len;
     489             :         char       *cbind_input;
     490             :         size_t      cbind_input_len;
     491             :         int         encoded_cbind_len;
     492             : 
     493             :         /* Fetch hash data of server's SSL certificate */
     494             :         cbind_data =
     495           6 :             pgtls_get_peer_certificate_hash(state->conn,
     496             :                                             &cbind_data_len);
     497           6 :         if (cbind_data == NULL)
     498             :         {
     499             :             /* error message is already set on error */
     500           0 :             termPQExpBuffer(&buf);
     501           0 :             return NULL;
     502             :         }
     503             : 
     504           6 :         appendPQExpBufferStr(&buf, "c=");
     505             : 
     506             :         /* p=type,, */
     507           6 :         cbind_header_len = strlen("p=tls-server-end-point,,");
     508           6 :         cbind_input_len = cbind_header_len + cbind_data_len;
     509           6 :         cbind_input = malloc(cbind_input_len);
     510           6 :         if (!cbind_input)
     511             :         {
     512           0 :             free(cbind_data);
     513           0 :             goto oom_error;
     514             :         }
     515           6 :         memcpy(cbind_input, "p=tls-server-end-point,,", cbind_header_len);
     516           6 :         memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len);
     517             : 
     518           6 :         encoded_cbind_len = pg_b64_enc_len(cbind_input_len);
     519           6 :         if (!enlargePQExpBuffer(&buf, encoded_cbind_len))
     520             :         {
     521           0 :             free(cbind_data);
     522           0 :             free(cbind_input);
     523           0 :             goto oom_error;
     524             :         }
     525           6 :         encoded_cbind_len = pg_b64_encode(cbind_input, cbind_input_len,
     526           6 :                                           buf.data + buf.len,
     527             :                                           encoded_cbind_len);
     528           6 :         if (encoded_cbind_len < 0)
     529             :         {
     530           0 :             free(cbind_data);
     531           0 :             free(cbind_input);
     532           0 :             termPQExpBuffer(&buf);
     533           0 :             appendPQExpBufferStr(&conn->errorMessage,
     534             :                                  "could not encode cbind data for channel binding\n");
     535           0 :             return NULL;
     536             :         }
     537           6 :         buf.len += encoded_cbind_len;
     538           6 :         buf.data[buf.len] = '\0';
     539             : 
     540           6 :         free(cbind_data);
     541           6 :         free(cbind_input);
     542             : #else
     543             :         /*
     544             :          * Chose channel binding, but the SSL library doesn't support it.
     545             :          * Shouldn't happen.
     546             :          */
     547             :         termPQExpBuffer(&buf);
     548             :         appendPQExpBufferStr(&conn->errorMessage,
     549             :                              "channel binding not supported by this build\n");
     550             :         return NULL;
     551             : #endif                          /* HAVE_PGTLS_GET_PEER_CERTIFICATE_HASH */
     552             :     }
     553             : #ifdef HAVE_PGTLS_GET_PEER_CERTIFICATE_HASH
     554          34 :     else if (conn->channel_binding[0] != 'd' && /* disable */
     555          32 :              conn->ssl_in_use)
     556           0 :         appendPQExpBufferStr(&buf, "c=eSws"); /* base64 of "y,," */
     557             : #endif
     558             :     else
     559          34 :         appendPQExpBufferStr(&buf, "c=biws"); /* base64 of "n,," */
     560             : 
     561          40 :     if (PQExpBufferDataBroken(buf))
     562           0 :         goto oom_error;
     563             : 
     564          40 :     appendPQExpBuffer(&buf, ",r=%s", state->nonce);
     565          40 :     if (PQExpBufferDataBroken(buf))
     566           0 :         goto oom_error;
     567             : 
     568          40 :     state->client_final_message_without_proof = strdup(buf.data);
     569          40 :     if (state->client_final_message_without_proof == NULL)
     570           0 :         goto oom_error;
     571             : 
     572             :     /* Append proof to it, to form client-final-message. */
     573          40 :     if (!calculate_client_proof(state,
     574          40 :                                 state->client_final_message_without_proof,
     575             :                                 client_proof))
     576             :     {
     577           0 :         termPQExpBuffer(&buf);
     578           0 :         appendPQExpBufferStr(&conn->errorMessage,
     579           0 :                              libpq_gettext("could not calculate client proof\n"));
     580           0 :         return NULL;
     581             :     }
     582             : 
     583          40 :     appendPQExpBufferStr(&buf, ",p=");
     584          40 :     encoded_len = pg_b64_enc_len(SCRAM_KEY_LEN);
     585          40 :     if (!enlargePQExpBuffer(&buf, encoded_len))
     586           0 :         goto oom_error;
     587          40 :     encoded_len = pg_b64_encode((char *) client_proof,
     588             :                                 SCRAM_KEY_LEN,
     589          40 :                                 buf.data + buf.len,
     590             :                                 encoded_len);
     591          40 :     if (encoded_len < 0)
     592             :     {
     593           0 :         termPQExpBuffer(&buf);
     594           0 :         appendPQExpBufferStr(&conn->errorMessage,
     595           0 :                              libpq_gettext("could not encode client proof\n"));
     596           0 :         return NULL;
     597             :     }
     598          40 :     buf.len += encoded_len;
     599          40 :     buf.data[buf.len] = '\0';
     600             : 
     601          40 :     result = strdup(buf.data);
     602          40 :     if (result == NULL)
     603           0 :         goto oom_error;
     604             : 
     605          40 :     termPQExpBuffer(&buf);
     606          40 :     return result;
     607             : 
     608           0 : oom_error:
     609           0 :     termPQExpBuffer(&buf);
     610           0 :     appendPQExpBufferStr(&conn->errorMessage,
     611           0 :                          libpq_gettext("out of memory\n"));
     612           0 :     return NULL;
     613             : }
     614             : 
     615             : /*
     616             :  * Read the first exchange message coming from the server.
     617             :  */
     618             : static bool
     619          40 : read_server_first_message(fe_scram_state *state, char *input)
     620             : {
     621          40 :     PGconn     *conn = state->conn;
     622             :     char       *iterations_str;
     623             :     char       *endptr;
     624             :     char       *encoded_salt;
     625             :     char       *nonce;
     626             :     int         decoded_salt_len;
     627             : 
     628          40 :     state->server_first_message = strdup(input);
     629          40 :     if (state->server_first_message == NULL)
     630             :     {
     631           0 :         appendPQExpBufferStr(&conn->errorMessage,
     632           0 :                              libpq_gettext("out of memory\n"));
     633           0 :         return false;
     634             :     }
     635             : 
     636             :     /* parse the message */
     637          40 :     nonce = read_attr_value(&input, 'r',
     638             :                             &conn->errorMessage);
     639          40 :     if (nonce == NULL)
     640             :     {
     641             :         /* read_attr_value() has appended an error string */
     642           0 :         return false;
     643             :     }
     644             : 
     645             :     /* Verify immediately that the server used our part of the nonce */
     646          40 :     if (strlen(nonce) < strlen(state->client_nonce) ||
     647          40 :         memcmp(nonce, state->client_nonce, strlen(state->client_nonce)) != 0)
     648             :     {
     649           0 :         appendPQExpBufferStr(&conn->errorMessage,
     650           0 :                              libpq_gettext("invalid SCRAM response (nonce mismatch)\n"));
     651           0 :         return false;
     652             :     }
     653             : 
     654          40 :     state->nonce = strdup(nonce);
     655          40 :     if (state->nonce == NULL)
     656             :     {
     657           0 :         appendPQExpBufferStr(&conn->errorMessage,
     658           0 :                              libpq_gettext("out of memory\n"));
     659           0 :         return false;
     660             :     }
     661             : 
     662          40 :     encoded_salt = read_attr_value(&input, 's', &conn->errorMessage);
     663          40 :     if (encoded_salt == NULL)
     664             :     {
     665             :         /* read_attr_value() has appended an error string */
     666           0 :         return false;
     667             :     }
     668          40 :     decoded_salt_len = pg_b64_dec_len(strlen(encoded_salt));
     669          40 :     state->salt = malloc(decoded_salt_len);
     670          40 :     if (state->salt == NULL)
     671             :     {
     672           0 :         appendPQExpBufferStr(&conn->errorMessage,
     673           0 :                              libpq_gettext("out of memory\n"));
     674           0 :         return false;
     675             :     }
     676          80 :     state->saltlen = pg_b64_decode(encoded_salt,
     677          40 :                                    strlen(encoded_salt),
     678             :                                    state->salt,
     679             :                                    decoded_salt_len);
     680          40 :     if (state->saltlen < 0)
     681             :     {
     682           0 :         appendPQExpBufferStr(&conn->errorMessage,
     683           0 :                              libpq_gettext("malformed SCRAM message (invalid salt)\n"));
     684           0 :         return false;
     685             :     }
     686             : 
     687          40 :     iterations_str = read_attr_value(&input, 'i', &conn->errorMessage);
     688          40 :     if (iterations_str == NULL)
     689             :     {
     690             :         /* read_attr_value() has appended an error string */
     691           0 :         return false;
     692             :     }
     693          40 :     state->iterations = strtol(iterations_str, &endptr, 10);
     694          40 :     if (*endptr != '\0' || state->iterations < 1)
     695             :     {
     696           0 :         appendPQExpBufferStr(&conn->errorMessage,
     697           0 :                              libpq_gettext("malformed SCRAM message (invalid iteration count)\n"));
     698           0 :         return false;
     699             :     }
     700             : 
     701          40 :     if (*input != '\0')
     702           0 :         appendPQExpBufferStr(&conn->errorMessage,
     703           0 :                              libpq_gettext("malformed SCRAM message (garbage at end of server-first-message)\n"));
     704             : 
     705          40 :     return true;
     706             : }
     707             : 
     708             : /*
     709             :  * Read the final exchange message coming from the server.
     710             :  */
     711             : static bool
     712          28 : read_server_final_message(fe_scram_state *state, char *input)
     713             : {
     714          28 :     PGconn     *conn = state->conn;
     715             :     char       *encoded_server_signature;
     716             :     char       *decoded_server_signature;
     717             :     int         server_signature_len;
     718             : 
     719          28 :     state->server_final_message = strdup(input);
     720          28 :     if (!state->server_final_message)
     721             :     {
     722           0 :         appendPQExpBufferStr(&conn->errorMessage,
     723           0 :                              libpq_gettext("out of memory\n"));
     724           0 :         return false;
     725             :     }
     726             : 
     727             :     /* Check for error result. */
     728          28 :     if (*input == 'e')
     729             :     {
     730           0 :         char       *errmsg = read_attr_value(&input, 'e',
     731             :                                              &conn->errorMessage);
     732             : 
     733           0 :         if (errmsg == NULL)
     734             :         {
     735             :             /* read_attr_value() has appended an error message */
     736           0 :             return false;
     737             :         }
     738           0 :         appendPQExpBuffer(&conn->errorMessage,
     739           0 :                           libpq_gettext("error received from server in SCRAM exchange: %s\n"),
     740             :                           errmsg);
     741           0 :         return false;
     742             :     }
     743             : 
     744             :     /* Parse the message. */
     745          28 :     encoded_server_signature = read_attr_value(&input, 'v',
     746             :                                                &conn->errorMessage);
     747          28 :     if (encoded_server_signature == NULL)
     748             :     {
     749             :         /* read_attr_value() has appended an error message */
     750           0 :         return false;
     751             :     }
     752             : 
     753          28 :     if (*input != '\0')
     754           0 :         appendPQExpBufferStr(&conn->errorMessage,
     755           0 :                              libpq_gettext("malformed SCRAM message (garbage at end of server-final-message)\n"));
     756             : 
     757          28 :     server_signature_len = pg_b64_dec_len(strlen(encoded_server_signature));
     758          28 :     decoded_server_signature = malloc(server_signature_len);
     759          28 :     if (!decoded_server_signature)
     760             :     {
     761           0 :         appendPQExpBufferStr(&conn->errorMessage,
     762           0 :                              libpq_gettext("out of memory\n"));
     763           0 :         return false;
     764             :     }
     765             : 
     766          28 :     server_signature_len = pg_b64_decode(encoded_server_signature,
     767          28 :                                          strlen(encoded_server_signature),
     768             :                                          decoded_server_signature,
     769             :                                          server_signature_len);
     770          28 :     if (server_signature_len != SCRAM_KEY_LEN)
     771             :     {
     772           0 :         free(decoded_server_signature);
     773           0 :         appendPQExpBufferStr(&conn->errorMessage,
     774           0 :                              libpq_gettext("malformed SCRAM message (invalid server signature)\n"));
     775           0 :         return false;
     776             :     }
     777          28 :     memcpy(state->ServerSignature, decoded_server_signature, SCRAM_KEY_LEN);
     778          28 :     free(decoded_server_signature);
     779             : 
     780          28 :     return true;
     781             : }
     782             : 
     783             : /*
     784             :  * Calculate the client proof, part of the final exchange message sent
     785             :  * by the client.  Returns true on success, false on failure.
     786             :  */
     787             : static bool
     788          40 : calculate_client_proof(fe_scram_state *state,
     789             :                        const char *client_final_message_without_proof,
     790             :                        uint8 *result)
     791             : {
     792             :     uint8       StoredKey[SCRAM_KEY_LEN];
     793             :     uint8       ClientKey[SCRAM_KEY_LEN];
     794             :     uint8       ClientSignature[SCRAM_KEY_LEN];
     795             :     int         i;
     796             :     pg_hmac_ctx *ctx;
     797             : 
     798          40 :     ctx = pg_hmac_create(PG_SHA256);
     799          40 :     if (ctx == NULL)
     800           0 :         return false;
     801             : 
     802             :     /*
     803             :      * Calculate SaltedPassword, and store it in 'state' so that we can reuse
     804             :      * it later in verify_server_signature.
     805             :      */
     806          40 :     if (scram_SaltedPassword(state->password, state->salt, state->saltlen,
     807          80 :                              state->iterations, state->SaltedPassword) < 0 ||
     808          80 :         scram_ClientKey(state->SaltedPassword, ClientKey) < 0 ||
     809          80 :         scram_H(ClientKey, SCRAM_KEY_LEN, StoredKey) < 0 ||
     810          80 :         pg_hmac_init(ctx, StoredKey, SCRAM_KEY_LEN) < 0 ||
     811          40 :         pg_hmac_update(ctx,
     812          40 :                        (uint8 *) state->client_first_message_bare,
     813          80 :                        strlen(state->client_first_message_bare)) < 0 ||
     814          80 :         pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
     815          40 :         pg_hmac_update(ctx,
     816          40 :                        (uint8 *) state->server_first_message,
     817          80 :                        strlen(state->server_first_message)) < 0 ||
     818          80 :         pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
     819          40 :         pg_hmac_update(ctx,
     820             :                        (uint8 *) client_final_message_without_proof,
     821          40 :                        strlen(client_final_message_without_proof)) < 0 ||
     822          40 :         pg_hmac_final(ctx, ClientSignature, sizeof(ClientSignature)) < 0)
     823             :     {
     824           0 :         pg_hmac_free(ctx);
     825           0 :         return false;
     826             :     }
     827             : 
     828        1320 :     for (i = 0; i < SCRAM_KEY_LEN; i++)
     829        1280 :         result[i] = ClientKey[i] ^ ClientSignature[i];
     830             : 
     831          40 :     pg_hmac_free(ctx);
     832          40 :     return true;
     833             : }
     834             : 
     835             : /*
     836             :  * Validate the server signature, received as part of the final exchange
     837             :  * message received from the server.  *match tracks if the server signature
     838             :  * matched or not. Returns true if the server signature got verified, and
     839             :  * false for a processing error.
     840             :  */
     841             : static bool
     842          28 : verify_server_signature(fe_scram_state *state, bool *match)
     843             : {
     844             :     uint8       expected_ServerSignature[SCRAM_KEY_LEN];
     845             :     uint8       ServerKey[SCRAM_KEY_LEN];
     846             :     pg_hmac_ctx *ctx;
     847             : 
     848          28 :     ctx = pg_hmac_create(PG_SHA256);
     849          28 :     if (ctx == NULL)
     850           0 :         return false;
     851             : 
     852          56 :     if (scram_ServerKey(state->SaltedPassword, ServerKey) < 0 ||
     853             :     /* calculate ServerSignature */
     854          56 :         pg_hmac_init(ctx, ServerKey, SCRAM_KEY_LEN) < 0 ||
     855          28 :         pg_hmac_update(ctx,
     856          28 :                        (uint8 *) state->client_first_message_bare,
     857          56 :                        strlen(state->client_first_message_bare)) < 0 ||
     858          56 :         pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
     859          28 :         pg_hmac_update(ctx,
     860          28 :                        (uint8 *) state->server_first_message,
     861          56 :                        strlen(state->server_first_message)) < 0 ||
     862          56 :         pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
     863          28 :         pg_hmac_update(ctx,
     864          28 :                        (uint8 *) state->client_final_message_without_proof,
     865          56 :                        strlen(state->client_final_message_without_proof)) < 0 ||
     866          28 :         pg_hmac_final(ctx, expected_ServerSignature,
     867             :                       sizeof(expected_ServerSignature)) < 0)
     868             :     {
     869           0 :         pg_hmac_free(ctx);
     870           0 :         return false;
     871             :     }
     872             : 
     873          28 :     pg_hmac_free(ctx);
     874             : 
     875             :     /* signature processed, so now check after it */
     876          28 :     if (memcmp(expected_ServerSignature, state->ServerSignature, SCRAM_KEY_LEN) != 0)
     877           0 :         *match = false;
     878             :     else
     879          28 :         *match = true;
     880             : 
     881          28 :     return true;
     882             : }
     883             : 
     884             : /*
     885             :  * Build a new SCRAM secret.
     886             :  */
     887             : char *
     888           0 : pg_fe_scram_build_secret(const char *password)
     889             : {
     890             :     char       *prep_password;
     891             :     pg_saslprep_rc rc;
     892             :     char        saltbuf[SCRAM_DEFAULT_SALT_LEN];
     893             :     char       *result;
     894             : 
     895             :     /*
     896             :      * Normalize the password with SASLprep.  If that doesn't work, because
     897             :      * the password isn't valid UTF-8 or contains prohibited characters, just
     898             :      * proceed with the original password.  (See comments at top of file.)
     899             :      */
     900           0 :     rc = pg_saslprep(password, &prep_password);
     901           0 :     if (rc == SASLPREP_OOM)
     902           0 :         return NULL;
     903           0 :     if (rc == SASLPREP_SUCCESS)
     904           0 :         password = (const char *) prep_password;
     905             : 
     906             :     /* Generate a random salt */
     907           0 :     if (!pg_strong_random(saltbuf, SCRAM_DEFAULT_SALT_LEN))
     908             :     {
     909           0 :         if (prep_password)
     910           0 :             free(prep_password);
     911           0 :         return NULL;
     912             :     }
     913             : 
     914           0 :     result = scram_build_secret(saltbuf, SCRAM_DEFAULT_SALT_LEN,
     915             :                                 SCRAM_DEFAULT_ITERATIONS, password);
     916             : 
     917           0 :     if (prep_password)
     918           0 :         free(prep_password);
     919             : 
     920           0 :     return result;
     921             : }

Generated by: LCOV version 1.14