LCOV - code coverage report
Current view: top level - src/backend/libpq - auth-oauth.c (source / functions) Hit Total Coverage
Test: PostgreSQL 18devel Lines: 0 233 0.0 %
Date: 2025-02-22 07:14:56 Functions: 0 12 0.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : /*-------------------------------------------------------------------------
       2             :  *
       3             :  * auth-oauth.c
       4             :  *    Server-side implementation of the SASL OAUTHBEARER mechanism.
       5             :  *
       6             :  * See the following RFC for more details:
       7             :  * - RFC 7628: https://datatracker.ietf.org/doc/html/rfc7628
       8             :  *
       9             :  * Portions Copyright (c) 1996-2025, PostgreSQL Global Development Group
      10             :  * Portions Copyright (c) 1994, Regents of the University of California
      11             :  *
      12             :  * src/backend/libpq/auth-oauth.c
      13             :  *
      14             :  *-------------------------------------------------------------------------
      15             :  */
      16             : #include "postgres.h"
      17             : 
      18             : #include <unistd.h>
      19             : #include <fcntl.h>
      20             : 
      21             : #include "common/oauth-common.h"
      22             : #include "fmgr.h"
      23             : #include "lib/stringinfo.h"
      24             : #include "libpq/auth.h"
      25             : #include "libpq/hba.h"
      26             : #include "libpq/oauth.h"
      27             : #include "libpq/sasl.h"
      28             : #include "storage/fd.h"
      29             : #include "storage/ipc.h"
      30             : #include "utils/json.h"
      31             : #include "utils/varlena.h"
      32             : 
/*
 * GUC backing variable.  NOTE(review): presumably holds the raw value of the
 * oauth_validator_libraries setting; the GUC table entry is not in view.
 */
char       *oauth_validator_libraries_string = NULL;

/* SASL mechanism callbacks; installed in pg_be_oauth_mech below. */
static void oauth_get_mechanisms(Port *port, StringInfo buf);
static void *oauth_init(Port *port, const char *selected_mech, const char *shadow_pass);
static int  oauth_exchange(void *opaq, const char *input, int inputlen,
                           char **output, int *outputlen, const char **logdetail);

/*
 * Validator-module lifecycle.  load_validator_library() is called from
 * oauth_init() with the validator named on the matching HBA line.
 * NOTE(review): shutdown_validator_library()'s (void *arg) signature looks
 * like an exit-callback registration (storage/ipc.h is included) -- confirm
 * against its definition, which is not in view.
 */
static void load_validator_library(const char *libname);
static void shutdown_validator_library(void *arg);

/*
 * Private state and callback table of the currently loaded validator module.
 * NOTE(review): presumably set up by load_validator_library() and torn down
 * by shutdown_validator_library(); neither definition is in view.
 */
static ValidatorModuleState *validator_module_state;
static const OAuthValidatorCallbacks *ValidatorCallbacks;

/*
 * Mechanism declaration: the callback table through which the generic SASL
 * code (see libpq/sasl.h) drives OAUTHBEARER.  max_message_length is set to
 * PG_MAX_AUTH_TOKEN_LENGTH -- presumably the cap the SASL layer applies to
 * client messages for this mechanism; see libpq/sasl.h for the exact rule.
 */
const pg_be_sasl_mech pg_be_oauth_mech = {
    .get_mechanisms = oauth_get_mechanisms,
    .init = oauth_init,
    .exchange = oauth_exchange,

    .max_message_length = PG_MAX_AUTH_TOKEN_LENGTH,
};
      55             : 
/*
 * Valid states for the oauth_exchange() machine.
 *
 * OAUTH_STATE_INIT:     waiting for the client's initial response.
 * OAUTH_STATE_ERROR:    the token was rejected and the JSON error challenge
 *                       has been sent; the only acceptable client reply is a
 *                       single kvsep byte.
 * OAUTH_STATE_FINISHED: the exchange is over (success or failure); any
 *                       further call to oauth_exchange() raises an ERROR.
 */
enum oauth_state
{
    OAUTH_STATE_INIT = 0,
    OAUTH_STATE_ERROR,
    OAUTH_STATE_FINISHED,
};

/*
 * Mechanism callback state, allocated (zeroed) by oauth_init() and passed
 * back to oauth_exchange() as its opaque pointer.
 */
struct oauth_ctx
{
    enum oauth_state state;     /* current position in the exchange */
    Port       *port;           /* connection being authenticated */
    const char *issuer;         /* borrowed from port->hba->oauth_issuer */
    const char *scope;          /* borrowed from port->hba->oauth_scope */
};

/* Helpers for oauth_exchange(); definitions follow below. */
static char *sanitize_char(char c);
static char *parse_kvpairs_for_auth(char **input);
static void generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen);
static bool validate(Port *port, const char *auth);

/* Constants seen in an OAUTHBEARER client initial response. */
#define KVSEP 0x01              /* separator byte for key/value pairs */
#define AUTH_KEY "auth"           /* key containing the Authorization header */
#define BEARER_SCHEME "Bearer " /* required header scheme (case-insensitive!) */
      82             : 
      83             : /*
      84             :  * Retrieves the OAUTHBEARER mechanism list (currently a single item).
      85             :  *
      86             :  * For a full description of the API, see libpq/sasl.h.
      87             :  */
      88             : static void
      89           0 : oauth_get_mechanisms(Port *port, StringInfo buf)
      90             : {
      91             :     /* Only OAUTHBEARER is supported. */
      92           0 :     appendStringInfoString(buf, OAUTHBEARER_NAME);
      93           0 :     appendStringInfoChar(buf, '\0');
      94           0 : }
      95             : 
      96             : /*
      97             :  * Initializes mechanism state and loads the configured validator module.
      98             :  *
      99             :  * For a full description of the API, see libpq/sasl.h.
     100             :  */
     101             : static void *
     102           0 : oauth_init(Port *port, const char *selected_mech, const char *shadow_pass)
     103             : {
     104             :     struct oauth_ctx *ctx;
     105             : 
     106           0 :     if (strcmp(selected_mech, OAUTHBEARER_NAME) != 0)
     107           0 :         ereport(ERROR,
     108             :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     109             :                 errmsg("client selected an invalid SASL authentication mechanism"));
     110             : 
     111           0 :     ctx = palloc0(sizeof(*ctx));
     112             : 
     113           0 :     ctx->state = OAUTH_STATE_INIT;
     114           0 :     ctx->port = port;
     115             : 
     116             :     Assert(port->hba);
     117           0 :     ctx->issuer = port->hba->oauth_issuer;
     118           0 :     ctx->scope = port->hba->oauth_scope;
     119             : 
     120           0 :     load_validator_library(port->hba->oauth_validator);
     121             : 
     122           0 :     return ctx;
     123             : }
     124             : 
/*
 * Implements the OAUTHBEARER SASL exchange (RFC 7628, Sec. 3.2). This pulls
 * apart the client initial response and validates the Bearer token. It also
 * handles the dummy error response for a failed handshake, as described in
 * Sec. 3.2.3.
 *
 * Transitions performed here:
 *   INIT  + input == NULL       -> INIT,     empty challenge, CONTINUE
 *   INIT  + token rejected      -> ERROR,    JSON error challenge, CONTINUE
 *   INIT  + token accepted      -> FINISHED, no output, SUCCESS
 *   ERROR + lone kvsep          -> FINISHED, no output, FAILURE
 * A malformed message, or a call in any other state, raises ERROR.
 *
 * *output / *outputlen receive the challenge to send, or NULL / -1 when
 * there is none.  logdetail is not set by this implementation.
 *
 * For a full description of the API, see libpq/sasl.h.
 */
static int
oauth_exchange(void *opaq, const char *input, int inputlen,
               char **output, int *outputlen, const char **logdetail)
{
    char       *input_copy;
    char       *p;
    char        cbind_flag;
    char       *auth;
    int         status;

    struct oauth_ctx *ctx = opaq;

    /* Default: no challenge to send back. */
    *output = NULL;
    *outputlen = -1;

    /*
     * If the client didn't include an "Initial Client Response" in the
     * SASLInitialResponse message, send an empty challenge, to which the
     * client will respond with the same data that usually comes in the
     * Initial Client Response.
     */
    if (input == NULL)
    {
        Assert(ctx->state == OAUTH_STATE_INIT);

        *output = pstrdup("");
        *outputlen = 0;
        return PG_SASL_EXCHANGE_CONTINUE;
    }

    /*
     * Check that the input length agrees with the string length of the input.
     * Everything below parses the message as a C string, so an embedded NUL
     * would otherwise silently truncate it.
     */
    if (inputlen == 0)
        ereport(ERROR,
                errcode(ERRCODE_PROTOCOL_VIOLATION),
                errmsg("malformed OAUTHBEARER message"),
                errdetail("The message is empty."));
    if (inputlen != strlen(input))
        ereport(ERROR,
                errcode(ERRCODE_PROTOCOL_VIOLATION),
                errmsg("malformed OAUTHBEARER message"),
                errdetail("Message length does not match input length."));

    switch (ctx->state)
    {
        case OAUTH_STATE_INIT:
            /* Handle this case below. */
            break;

        case OAUTH_STATE_ERROR:

            /*
             * Only one response is valid for the client during authentication
             * failure: a single kvsep.
             */
            if (inputlen != 1 || *input != KVSEP)
                ereport(ERROR,
                        errcode(ERRCODE_PROTOCOL_VIOLATION),
                        errmsg("malformed OAUTHBEARER message"),
                        errdetail("Client did not send a kvsep response."));

            /* The (failed) handshake is now complete. */
            ctx->state = OAUTH_STATE_FINISHED;
            return PG_SASL_EXCHANGE_FAILURE;

        default:
            /* FINISHED (or corrupt state): no further messages are valid. */
            elog(ERROR, "invalid OAUTHBEARER exchange state");
            return PG_SASL_EXCHANGE_FAILURE;
    }

    /*
     * Handle the client's initial message.  parse_kvpairs_for_auth() writes
     * NUL terminators into the buffer it parses, so work on a private copy
     * rather than the caller's const input.
     */
    p = input_copy = pstrdup(input);

    /*
     * Parse the RFC 5801 gs2-header: gs2-cb-flag "," [gs2-authzid] ",".
     *
     * OAUTHBEARER does not currently define a channel binding (so there is no
     * OAUTHBEARER-PLUS, and we do not accept a 'p' specifier). We accept a
     * 'y' specifier purely for the remote chance that a future specification
     * could define one; then future clients can still interoperate with this
     * server implementation. 'n' is the expected case.
     */
    cbind_flag = *p;
    switch (cbind_flag)
    {
        case 'p':
            ereport(ERROR,
                    errcode(ERRCODE_PROTOCOL_VIOLATION),
                    errmsg("malformed OAUTHBEARER message"),
                    errdetail("The server does not support channel binding for OAuth, but the client message includes channel binding data."));
            break;

        case 'y':               /* fall through */
        case 'n':
            p++;
            if (*p != ',')
                ereport(ERROR,
                        errcode(ERRCODE_PROTOCOL_VIOLATION),
                        errmsg("malformed OAUTHBEARER message"),
                        errdetail("Comma expected, but found character \"%s\".",
                                  sanitize_char(*p)));
            p++;
            break;

        default:
            ereport(ERROR,
                    errcode(ERRCODE_PROTOCOL_VIOLATION),
                    errmsg("malformed OAUTHBEARER message"),
                    errdetail("Unexpected channel-binding flag \"%s\".",
                              sanitize_char(cbind_flag)));
    }

    /*
     * Forbid optional authzid (authorization identity).  We don't support it.
     */
    if (*p == 'a')
        ereport(ERROR,
                errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
                errmsg("client uses authorization identity, but it is not supported"));
    if (*p != ',')
        ereport(ERROR,
                errcode(ERRCODE_PROTOCOL_VIOLATION),
                errmsg("malformed OAUTHBEARER message"),
                errdetail("Unexpected attribute \"%s\" in client-first-message.",
                          sanitize_char(*p)));
    p++;

    /* All remaining fields are separated by the RFC's kvsep (\x01). */
    if (*p != KVSEP)
        ereport(ERROR,
                errcode(ERRCODE_PROTOCOL_VIOLATION),
                errmsg("malformed OAUTHBEARER message"),
                errdetail("Key-value separator expected, but found character \"%s\".",
                          sanitize_char(*p)));
    p++;

    /* auth points into input_copy; p is advanced past the final kvsep. */
    auth = parse_kvpairs_for_auth(&p);
    if (!auth)
        ereport(ERROR,
                errcode(ERRCODE_PROTOCOL_VIOLATION),
                errmsg("malformed OAUTHBEARER message"),
                errdetail("Message does not contain an auth value."));

    /* We should be at the end of our message. */
    if (*p)
        ereport(ERROR,
                errcode(ERRCODE_PROTOCOL_VIOLATION),
                errmsg("malformed OAUTHBEARER message"),
                errdetail("Message contains additional data after the final terminator."));

    if (!validate(ctx->port, auth))
    {
        /* Send the RFC 7628 Sec. 3.2.2 error challenge; await a kvsep. */
        generate_error_response(ctx, output, outputlen);

        ctx->state = OAUTH_STATE_ERROR;
        status = PG_SASL_EXCHANGE_CONTINUE;
    }
    else
    {
        ctx->state = OAUTH_STATE_FINISHED;
        status = PG_SASL_EXCHANGE_SUCCESS;
    }

    /*
     * Don't let extra copies of the bearer token hang around.  input_copy is
     * inputlen + 1 bytes; only the trailing NUL is left untouched.
     * NOTE(review): the copy is not scrubbed on the ereport(ERROR) paths
     * above, and it is never pfree'd here -- confirm that is intended.
     */
    explicit_bzero(input_copy, inputlen);

    return status;
}
     300             : 
     301             : /*
     302             :  * Convert an arbitrary byte to printable form.  For error messages.
     303             :  *
     304             :  * If it's a printable ASCII character, print it as a single character.
     305             :  * otherwise, print it in hex.
     306             :  *
     307             :  * The returned pointer points to a static buffer.
     308             :  */
     309             : static char *
     310           0 : sanitize_char(char c)
     311             : {
     312             :     static char buf[5];
     313             : 
     314           0 :     if (c >= 0x21 && c <= 0x7E)
     315           0 :         snprintf(buf, sizeof(buf), "'%c'", c);
     316             :     else
     317           0 :         snprintf(buf, sizeof(buf), "0x%02x", (unsigned char) c);
     318           0 :     return buf;
     319             : }
     320             : 
     321             : /*
     322             :  * Performs syntactic validation of a key and value from the initial client
     323             :  * response. (Semantic validation of interesting values must be performed
     324             :  * later.)
     325             :  */
     326             : static void
     327           0 : validate_kvpair(const char *key, const char *val)
     328             : {
     329             :     /*-----
     330             :      * From Sec 3.1:
     331             :      *     key            = 1*(ALPHA)
     332             :      */
     333             :     static const char *key_allowed_set =
     334             :         "abcdefghijklmnopqrstuvwxyz"
     335             :         "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
     336             : 
     337             :     size_t      span;
     338             : 
     339           0 :     if (!key[0])
     340           0 :         ereport(ERROR,
     341             :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     342             :                 errmsg("malformed OAUTHBEARER message"),
     343             :                 errdetail("Message contains an empty key name."));
     344             : 
     345           0 :     span = strspn(key, key_allowed_set);
     346           0 :     if (key[span] != '\0')
     347           0 :         ereport(ERROR,
     348             :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     349             :                 errmsg("malformed OAUTHBEARER message"),
     350             :                 errdetail("Message contains an invalid key name."));
     351             : 
     352             :     /*-----
     353             :      * From Sec 3.1:
     354             :      *     value          = *(VCHAR / SP / HTAB / CR / LF )
     355             :      *
     356             :      * The VCHAR (visible character) class is large; a loop is more
     357             :      * straightforward than strspn().
     358             :      */
     359           0 :     for (; *val; ++val)
     360             :     {
     361           0 :         if (0x21 <= *val && *val <= 0x7E)
     362           0 :             continue;           /* VCHAR */
     363             : 
     364           0 :         switch (*val)
     365             :         {
     366           0 :             case ' ':
     367             :             case '\t':
     368             :             case '\r':
     369             :             case '\n':
     370           0 :                 continue;       /* SP, HTAB, CR, LF */
     371             : 
     372           0 :             default:
     373           0 :                 ereport(ERROR,
     374             :                         errcode(ERRCODE_PROTOCOL_VIOLATION),
     375             :                         errmsg("malformed OAUTHBEARER message"),
     376             :                         errdetail("Message contains an invalid value."));
     377             :         }
     378             :     }
     379           0 : }
     380             : 
     381             : /*
     382             :  * Consumes all kvpairs in an OAUTHBEARER exchange message. If the "auth" key is
     383             :  * found, its value is returned.
     384             :  */
     385             : static char *
     386           0 : parse_kvpairs_for_auth(char **input)
     387             : {
     388           0 :     char       *pos = *input;
     389           0 :     char       *auth = NULL;
     390             : 
     391             :     /*----
     392             :      * The relevant ABNF, from Sec. 3.1:
     393             :      *
     394             :      *     kvsep          = %x01
     395             :      *     key            = 1*(ALPHA)
     396             :      *     value          = *(VCHAR / SP / HTAB / CR / LF )
     397             :      *     kvpair         = key "=" value kvsep
     398             :      *   ;;gs2-header     = See RFC 5801
     399             :      *     client-resp    = (gs2-header kvsep *kvpair kvsep) / kvsep
     400             :      *
     401             :      * By the time we reach this code, the gs2-header and initial kvsep have
     402             :      * already been validated. We start at the beginning of the first kvpair.
     403             :      */
     404             : 
     405           0 :     while (*pos)
     406             :     {
     407             :         char       *end;
     408             :         char       *sep;
     409             :         char       *key;
     410             :         char       *value;
     411             : 
     412             :         /*
     413             :          * Find the end of this kvpair. Note that input is null-terminated by
     414             :          * the SASL code, so the strchr() is bounded.
     415             :          */
     416           0 :         end = strchr(pos, KVSEP);
     417           0 :         if (!end)
     418           0 :             ereport(ERROR,
     419             :                     errcode(ERRCODE_PROTOCOL_VIOLATION),
     420             :                     errmsg("malformed OAUTHBEARER message"),
     421             :                     errdetail("Message contains an unterminated key/value pair."));
     422           0 :         *end = '\0';
     423             : 
     424           0 :         if (pos == end)
     425             :         {
     426             :             /* Empty kvpair, signifying the end of the list. */
     427           0 :             *input = pos + 1;
     428           0 :             return auth;
     429             :         }
     430             : 
     431             :         /*
     432             :          * Find the end of the key name.
     433             :          */
     434           0 :         sep = strchr(pos, '=');
     435           0 :         if (!sep)
     436           0 :             ereport(ERROR,
     437             :                     errcode(ERRCODE_PROTOCOL_VIOLATION),
     438             :                     errmsg("malformed OAUTHBEARER message"),
     439             :                     errdetail("Message contains a key without a value."));
     440           0 :         *sep = '\0';
     441             : 
     442             :         /* Both key and value are now safely terminated. */
     443           0 :         key = pos;
     444           0 :         value = sep + 1;
     445           0 :         validate_kvpair(key, value);
     446             : 
     447           0 :         if (strcmp(key, AUTH_KEY) == 0)
     448             :         {
     449           0 :             if (auth)
     450           0 :                 ereport(ERROR,
     451             :                         errcode(ERRCODE_PROTOCOL_VIOLATION),
     452             :                         errmsg("malformed OAUTHBEARER message"),
     453             :                         errdetail("Message contains multiple auth values."));
     454             : 
     455           0 :             auth = value;
     456             :         }
     457             :         else
     458             :         {
     459             :             /*
     460             :              * The RFC also defines the host and port keys, but they are not
     461             :              * required for OAUTHBEARER and we do not use them. Also, per Sec.
     462             :              * 3.1, any key/value pairs we don't recognize must be ignored.
     463             :              */
     464             :         }
     465             : 
     466             :         /* Move to the next pair. */
     467           0 :         pos = end + 1;
     468             :     }
     469             : 
     470           0 :     ereport(ERROR,
     471             :             errcode(ERRCODE_PROTOCOL_VIOLATION),
     472             :             errmsg("malformed OAUTHBEARER message"),
     473             :             errdetail("Message did not contain a final terminator."));
     474             : 
     475             :     pg_unreachable();
     476             :     return NULL;
     477             : }
     478             : 
     479             : /*
     480             :  * Builds the JSON response for failed authentication (RFC 7628, Sec. 3.2.2).
     481             :  * This contains the required scopes for entry and a pointer to the OAuth/OpenID
     482             :  * discovery document, which the client may use to conduct its OAuth flow.
     483             :  */
     484             : static void
     485           0 : generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen)
     486             : {
     487             :     StringInfoData buf;
     488             :     StringInfoData issuer;
     489             : 
     490             :     /*
     491             :      * The admin needs to set an issuer and scope for OAuth to work. There's
     492             :      * not really a way to hide this from the user, either, because we can't
     493             :      * choose a "default" issuer, so be honest in the failure message. (In
     494             :      * practice such configurations are rejected during HBA parsing.)
     495             :      */
     496           0 :     if (!ctx->issuer || !ctx->scope)
     497           0 :         ereport(FATAL,
     498             :                 errcode(ERRCODE_INTERNAL_ERROR),
     499             :                 errmsg("OAuth is not properly configured for this user"),
     500             :                 errdetail_log("The issuer and scope parameters must be set in pg_hba.conf."));
     501             : 
     502             :     /*
     503             :      * Build a default .well-known URI based on our issuer, unless the HBA has
     504             :      * already provided one.
     505             :      */
     506           0 :     initStringInfo(&issuer);
     507           0 :     appendStringInfoString(&issuer, ctx->issuer);
     508           0 :     if (strstr(ctx->issuer, "/.well-known/") == NULL)
     509           0 :         appendStringInfoString(&issuer, "/.well-known/openid-configuration");
     510             : 
     511           0 :     initStringInfo(&buf);
     512             : 
     513             :     /*
     514             :      * Escaping the string here is belt-and-suspenders defensive programming
     515             :      * since escapable characters aren't valid in either the issuer URI or the
     516             :      * scope list, but the HBA doesn't enforce that yet.
     517             :      */
     518           0 :     appendStringInfoString(&buf, "{ \"status\": \"invalid_token\", ");
     519             : 
     520           0 :     appendStringInfoString(&buf, "\"openid-configuration\": ");
     521           0 :     escape_json(&buf, issuer.data);
     522           0 :     pfree(issuer.data);
     523             : 
     524           0 :     appendStringInfoString(&buf, ", \"scope\": ");
     525           0 :     escape_json(&buf, ctx->scope);
     526             : 
     527           0 :     appendStringInfoString(&buf, " }");
     528             : 
     529           0 :     *output = buf.data;
     530           0 :     *outputlen = buf.len;
     531           0 : }
     532             : 
     533             : /*-----
     534             :  * Validates the provided Authorization header and returns the token from
     535             :  * within it. NULL is returned on validation failure.
     536             :  *
     537             :  * Only Bearer tokens are accepted. The ABNF is defined in RFC 6750, Sec.
     538             :  * 2.1:
     539             :  *
     540             :  *      b64token    = 1*( ALPHA / DIGIT /
     541             :  *                        "-" / "." / "_" / "~" / "+" / "/" ) *"="
     542             :  *      credentials = "Bearer" 1*SP b64token
     543             :  *
     544             :  * The "credentials" construction is what we receive in our auth value.
     545             :  *
     546             :  * Since that spec is subordinate to HTTP (i.e. the HTTP Authorization
     547             :  * header format; RFC 9110 Sec. 11), the "Bearer" scheme string must be
     548             :  * compared case-insensitively. (This is not mentioned in RFC 6750, but the
     549             :  * OAUTHBEARER spec points it out: RFC 7628 Sec. 4.)
     550             :  *
     551             :  * Invalid formats are technically a protocol violation, but we shouldn't
     552             :  * reflect any information about the sensitive Bearer token back to the
     553             :  * client; log at COMMERROR instead.
     554             :  */
     555             : static const char *
     556           0 : validate_token_format(const char *header)
     557             : {
     558             :     size_t      span;
     559             :     const char *token;
     560             :     static const char *const b64token_allowed_set =
     561             :         "abcdefghijklmnopqrstuvwxyz"
     562             :         "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
     563             :         "0123456789-._~+/";
     564             : 
     565             :     /* Missing auth headers should be handled by the caller. */
     566             :     Assert(header);
     567             : 
     568           0 :     if (header[0] == '\0')
     569             :     {
     570             :         /*
     571             :          * A completely empty auth header represents a query for
     572             :          * authentication parameters. The client expects it to fail; there's
     573             :          * no need to make any extra noise in the logs.
     574             :          *
     575             :          * TODO: should we find a way to return STATUS_EOF at the top level,
     576             :          * to suppress the authentication error entirely?
     577             :          */
     578           0 :         return NULL;
     579             :     }
     580             : 
     581           0 :     if (pg_strncasecmp(header, BEARER_SCHEME, strlen(BEARER_SCHEME)))
     582             :     {
     583           0 :         ereport(COMMERROR,
     584             :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     585             :                 errmsg("malformed OAuth bearer token"),
     586             :                 errdetail_log("Client response indicated a non-Bearer authentication scheme."));
     587           0 :         return NULL;
     588             :     }
     589             : 
     590             :     /* Pull the bearer token out of the auth value. */
     591           0 :     token = header + strlen(BEARER_SCHEME);
     592             : 
     593             :     /* Swallow any additional spaces. */
     594           0 :     while (*token == ' ')
     595           0 :         token++;
     596             : 
     597             :     /* Tokens must not be empty. */
     598           0 :     if (!*token)
     599             :     {
     600           0 :         ereport(COMMERROR,
     601             :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     602             :                 errmsg("malformed OAuth bearer token"),
     603             :                 errdetail_log("Bearer token is empty."));
     604           0 :         return NULL;
     605             :     }
     606             : 
     607             :     /*
     608             :      * Make sure the token contains only allowed characters. Tokens may end
     609             :      * with any number of '=' characters.
     610             :      */
     611           0 :     span = strspn(token, b64token_allowed_set);
     612           0 :     while (token[span] == '=')
     613           0 :         span++;
     614             : 
     615           0 :     if (token[span] != '\0')
     616             :     {
     617             :         /*
     618             :          * This error message could be more helpful by printing the
     619             :          * problematic character(s), but that'd be a bit like printing a piece
     620             :          * of someone's password into the logs.
     621             :          */
     622           0 :         ereport(COMMERROR,
     623             :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     624             :                 errmsg("malformed OAuth bearer token"),
     625             :                 errdetail_log("Bearer token is not in the correct format."));
     626           0 :         return NULL;
     627             :     }
     628             : 
     629           0 :     return token;
     630             : }
     631             : 
     632             : /*
     633             :  * Checks that the "auth" kvpair in the client response contains a syntactically
     634             :  * valid Bearer token, then passes it along to the loaded validator module for
     635             :  * authorization. Returns true if validation succeeds.
     636             :  */
     637             : static bool
     638           0 : validate(Port *port, const char *auth)
     639             : {
     640             :     int         map_status;
     641             :     ValidatorModuleResult *ret;
     642             :     const char *token;
     643             :     bool        status;
     644             : 
     645             :     /* Ensure that we have a correct token to validate */
     646           0 :     if (!(token = validate_token_format(auth)))
     647           0 :         return false;
     648             : 
     649             :     /*
     650             :      * Ensure that we have a validation library loaded, this should always be
     651             :      * the case and an error here is indicative of a bug.
     652             :      */
     653           0 :     if (!ValidatorCallbacks || !ValidatorCallbacks->validate_cb)
     654           0 :         ereport(FATAL,
     655             :                 errcode(ERRCODE_INTERNAL_ERROR),
     656             :                 errmsg("validation of OAuth token requested without a validator loaded"));
     657             : 
     658             :     /* Call the validation function from the validator module */
     659           0 :     ret = palloc0(sizeof(ValidatorModuleResult));
     660           0 :     if (!ValidatorCallbacks->validate_cb(validator_module_state, token,
     661           0 :                                          port->user_name, ret))
     662             :     {
     663           0 :         ereport(WARNING,
     664             :                 errcode(ERRCODE_INTERNAL_ERROR),
     665             :                 errmsg("internal error in OAuth validator module"));
     666           0 :         return false;
     667             :     }
     668             : 
     669             :     /*
     670             :      * Log any authentication results even if the token isn't authorized; it
     671             :      * might be useful for auditing or troubleshooting.
     672             :      */
     673           0 :     if (ret->authn_id)
     674           0 :         set_authn_id(port, ret->authn_id);
     675             : 
     676           0 :     if (!ret->authorized)
     677             :     {
     678           0 :         ereport(LOG,
     679             :                 errmsg("OAuth bearer authentication failed for user \"%s\"",
     680             :                        port->user_name),
     681             :                 errdetail_log("Validator failed to authorize the provided token."));
     682             : 
     683           0 :         status = false;
     684           0 :         goto cleanup;
     685             :     }
     686             : 
     687           0 :     if (port->hba->oauth_skip_usermap)
     688             :     {
     689             :         /*
     690             :          * If the validator is our authorization authority, we're done.
     691             :          * Authentication may or may not have been performed depending on the
     692             :          * validator implementation; all that matters is that the validator
     693             :          * says the user can log in with the target role.
     694             :          */
     695           0 :         status = true;
     696           0 :         goto cleanup;
     697             :     }
     698             : 
     699             :     /* Make sure the validator authenticated the user. */
     700           0 :     if (ret->authn_id == NULL || ret->authn_id[0] == '\0')
     701             :     {
     702           0 :         ereport(LOG,
     703             :                 errmsg("OAuth bearer authentication failed for user \"%s\"",
     704             :                        port->user_name),
     705             :                 errdetail_log("Validator provided no identity."));
     706             : 
     707           0 :         status = false;
     708           0 :         goto cleanup;
     709             :     }
     710             : 
     711             :     /* Finally, check the user map. */
     712           0 :     map_status = check_usermap(port->hba->usermap, port->user_name,
     713             :                                MyClientConnectionInfo.authn_id, false);
     714           0 :     status = (map_status == STATUS_OK);
     715             : 
     716           0 : cleanup:
     717             : 
     718             :     /*
     719             :      * Clear and free the validation result from the validator module once
     720             :      * we're done with it.
     721             :      */
     722           0 :     if (ret->authn_id != NULL)
     723           0 :         pfree(ret->authn_id);
     724           0 :     pfree(ret);
     725             : 
     726           0 :     return status;
     727             : }
     728             : 
     729             : /*
     730             :  * load_validator_library
     731             :  *
     732             :  * Load the configured validator library in order to perform token validation.
     733             :  * There is no built-in fallback since validation is implementation specific. If
     734             :  * no validator library is configured, or if it fails to load, then error out
     735             :  * since token validation won't be possible.
     736             :  */
     737             : static void
     738           0 : load_validator_library(const char *libname)
     739             : {
     740             :     OAuthValidatorModuleInit validator_init;
     741             :     MemoryContextCallback *mcb;
     742             : 
     743             :     /*
     744             :      * The presence, and validity, of libname has already been established by
     745             :      * check_oauth_validator so we don't need to perform more than Assert
     746             :      * level checking here.
     747             :      */
     748             :     Assert(libname && *libname);
     749             : 
     750           0 :     validator_init = (OAuthValidatorModuleInit)
     751           0 :         load_external_function(libname, "_PG_oauth_validator_module_init",
     752             :                                false, NULL);
     753             : 
     754             :     /*
     755             :      * The validator init function is required since it will set the callbacks
     756             :      * for the validator library.
     757             :      */
     758           0 :     if (validator_init == NULL)
     759           0 :         ereport(ERROR,
     760             :                 errmsg("%s module \"%s\" must define the symbol %s",
     761             :                        "OAuth validator", libname, "_PG_oauth_validator_module_init"));
     762             : 
     763           0 :     ValidatorCallbacks = (*validator_init) ();
     764             :     Assert(ValidatorCallbacks);
     765             : 
     766             :     /*
     767             :      * Check the magic number, to protect against break-glass scenarios where
     768             :      * the ABI must change within a major version. load_external_function()
     769             :      * already checks for compatibility across major versions.
     770             :      */
     771           0 :     if (ValidatorCallbacks->magic != PG_OAUTH_VALIDATOR_MAGIC)
     772           0 :         ereport(ERROR,
     773             :                 errmsg("%s module \"%s\": magic number mismatch",
     774             :                        "OAuth validator", libname),
     775             :                 errdetail("Server has magic number 0x%08X, module has 0x%08X.",
     776             :                           PG_OAUTH_VALIDATOR_MAGIC, ValidatorCallbacks->magic));
     777             : 
     778             :     /*
     779             :      * Make sure all required callbacks are present in the ValidatorCallbacks
     780             :      * structure. Right now only the validation callback is required.
     781             :      */
     782           0 :     if (ValidatorCallbacks->validate_cb == NULL)
     783           0 :         ereport(ERROR,
     784             :                 errmsg("%s module \"%s\" must provide a %s callback",
     785             :                        "OAuth validator", libname, "validate_cb"));
     786             : 
     787             :     /* Allocate memory for validator library private state data */
     788           0 :     validator_module_state = (ValidatorModuleState *) palloc0(sizeof(ValidatorModuleState));
     789           0 :     validator_module_state->sversion = PG_VERSION_NUM;
     790             : 
     791           0 :     if (ValidatorCallbacks->startup_cb != NULL)
     792           0 :         ValidatorCallbacks->startup_cb(validator_module_state);
     793             : 
     794             :     /* Shut down the library before cleaning up its state. */
     795           0 :     mcb = palloc0(sizeof(*mcb));
     796           0 :     mcb->func = shutdown_validator_library;
     797             : 
     798           0 :     MemoryContextRegisterResetCallback(CurrentMemoryContext, mcb);
     799           0 : }
     800             : 
     801             : /*
     802             :  * Call the validator module's shutdown callback, if one is provided. This is
     803             :  * invoked during memory context reset.
     804             :  */
     805             : static void
     806           0 : shutdown_validator_library(void *arg)
     807             : {
     808           0 :     if (ValidatorCallbacks->shutdown_cb != NULL)
     809           0 :         ValidatorCallbacks->shutdown_cb(validator_module_state);
     810           0 : }
     811             : 
     812             : /*
     813             :  * Ensure an OAuth validator named in the HBA is permitted by the configuration.
     814             :  *
     815             :  * If the validator is currently unset and exactly one library is declared in
     816             :  * oauth_validator_libraries, then that library will be used as the validator.
     817             :  * Otherwise the name must be present in the list of oauth_validator_libraries.
     818             :  */
     819             : bool
     820           0 : check_oauth_validator(HbaLine *hbaline, int elevel, char **err_msg)
     821             : {
     822           0 :     int         line_num = hbaline->linenumber;
     823           0 :     const char *file_name = hbaline->sourcefile;
     824             :     char       *rawstring;
     825           0 :     List       *elemlist = NIL;
     826             : 
     827           0 :     *err_msg = NULL;
     828             : 
     829           0 :     if (oauth_validator_libraries_string[0] == '\0')
     830             :     {
     831           0 :         ereport(elevel,
     832             :                 errcode(ERRCODE_CONFIG_FILE_ERROR),
     833             :                 errmsg("oauth_validator_libraries must be set for authentication method %s",
     834             :                        "oauth"),
     835             :                 errcontext("line %d of configuration file \"%s\"",
     836             :                            line_num, file_name));
     837           0 :         *err_msg = psprintf("oauth_validator_libraries must be set for authentication method %s",
     838             :                             "oauth");
     839           0 :         return false;
     840             :     }
     841             : 
     842             :     /* SplitDirectoriesString needs a modifiable copy */
     843           0 :     rawstring = pstrdup(oauth_validator_libraries_string);
     844             : 
     845           0 :     if (!SplitDirectoriesString(rawstring, ',', &elemlist))
     846             :     {
     847             :         /* syntax error in list */
     848           0 :         ereport(elevel,
     849             :                 errcode(ERRCODE_CONFIG_FILE_ERROR),
     850             :                 errmsg("invalid list syntax in parameter \"%s\"",
     851             :                        "oauth_validator_libraries"));
     852           0 :         *err_msg = psprintf("invalid list syntax in parameter \"%s\"",
     853             :                             "oauth_validator_libraries");
     854           0 :         goto done;
     855             :     }
     856             : 
     857           0 :     if (!hbaline->oauth_validator)
     858             :     {
     859           0 :         if (elemlist->length == 1)
     860             :         {
     861           0 :             hbaline->oauth_validator = pstrdup(linitial(elemlist));
     862           0 :             goto done;
     863             :         }
     864             : 
     865           0 :         ereport(elevel,
     866             :                 errcode(ERRCODE_CONFIG_FILE_ERROR),
     867             :                 errmsg("authentication method \"oauth\" requires argument \"validator\" to be set when oauth_validator_libraries contains multiple options"),
     868             :                 errcontext("line %d of configuration file \"%s\"",
     869             :                            line_num, file_name));
     870           0 :         *err_msg = "authentication method \"oauth\" requires argument \"validator\" to be set when oauth_validator_libraries contains multiple options";
     871           0 :         goto done;
     872             :     }
     873             : 
     874           0 :     foreach_ptr(char, allowed, elemlist)
     875             :     {
     876           0 :         if (strcmp(allowed, hbaline->oauth_validator) == 0)
     877           0 :             goto done;
     878             :     }
     879             : 
     880           0 :     ereport(elevel,
     881             :             errcode(ERRCODE_INVALID_PARAMETER_VALUE),
     882             :             errmsg("validator \"%s\" is not permitted by %s",
     883             :                    hbaline->oauth_validator, "oauth_validator_libraries"),
     884             :             errcontext("line %d of configuration file \"%s\"",
     885             :                        line_num, file_name));
     886           0 :     *err_msg = psprintf("validator \"%s\" is not permitted by %s",
     887             :                         hbaline->oauth_validator, "oauth_validator_libraries");
     888             : 
     889           0 : done:
     890           0 :     list_free_deep(elemlist);
     891           0 :     pfree(rawstring);
     892             : 
     893           0 :     return (*err_msg == NULL);
     894             : }

Generated by: LCOV version 1.14