LCOV - code coverage report
Current view: top level - src/backend/libpq - auth-oauth.c (source / functions) Coverage Total Hit
Test: PostgreSQL 19devel Lines: 0.0 % 290 0
Test Date: 2026-05-02 10:16:34 Functions: 0.0 % 16 0
Legend: Lines:     hit not hit

            Line data    Source code
       1              : /*-------------------------------------------------------------------------
       2              :  *
       3              :  * auth-oauth.c
       4              :  *    Server-side implementation of the SASL OAUTHBEARER mechanism.
       5              :  *
       6              :  * See the following RFC for more details:
       7              :  * - RFC 7628: https://datatracker.ietf.org/doc/html/rfc7628
       8              :  *
       9              :  * Portions Copyright (c) 1996-2026, PostgreSQL Global Development Group
      10              :  * Portions Copyright (c) 1994, Regents of the University of California
      11              :  *
      12              :  * src/backend/libpq/auth-oauth.c
      13              :  *
      14              :  *-------------------------------------------------------------------------
      15              :  */
      16              : #include "postgres.h"
      17              : 
      18              : #include <unistd.h>
      19              : #include <fcntl.h>
      20              : 
      21              : #include "common/oauth-common.h"
      22              : #include "fmgr.h"
      23              : #include "lib/stringinfo.h"
      24              : #include "libpq/auth.h"
      25              : #include "libpq/hba.h"
      26              : #include "libpq/oauth.h"
      27              : #include "libpq/sasl.h"
      28              : #include "miscadmin.h"
      29              : #include "storage/fd.h"
      30              : #include "storage/ipc.h"
      31              : #include "utils/json.h"
      32              : #include "utils/varlena.h"
      33              : 
      34              : /* GUC variable; nothing in the visible part of this file reads it. */
      35              : char       *oauth_validator_libraries_string = NULL;
      36              : 
      37              : static void oauth_get_mechanisms(Port *port, StringInfo buf);   /* SASL callbacks, wired into pg_be_oauth_mech */
      38              : static void *oauth_init(Port *port, const char *selected_mech, const char *shadow_pass);
      39              : static int  oauth_exchange(void *opaq, const char *input, int inputlen,
      40              :                            char **output, int *outputlen, const char **logdetail);
      41              : 
      42              : static void load_validator_library(const char *libname);    /* called from oauth_init() */
      43              : static void shutdown_validator_library(void *arg);
      44              : static bool check_validator_hba_options(Port *port, const char **logdetail);    /* called from oauth_exchange() */
      45              : 
      46              : static ValidatorModuleState *validator_module_state;
      47              : static const OAuthValidatorCallbacks *ValidatorCallbacks;
      48              : 
      49              : static MemoryContext ValidatorMemoryContext;    /* set to CurrentMemoryContext by oauth_init() */
      50              : static List *ValidatorOptions;
      51              : static bool ValidatorOptionsChecked;
      52              : 
      53              : /* Mechanism declaration: the SASL callback table for OAUTHBEARER, plus the maximum message length. */
      54              : const pg_be_sasl_mech pg_be_oauth_mech = {
      55              :     .get_mechanisms = oauth_get_mechanisms,
      56              :     .init = oauth_init,
      57              :     .exchange = oauth_exchange,
      58              : 
      59              :     .max_message_length = PG_MAX_AUTH_TOKEN_LENGTH,
      60              : };
      61              : 
      62              : /* Valid states for the oauth_exchange() machine. */
      63              : enum oauth_state
      64              : {
      65              :     OAUTH_STATE_INIT = 0,           /* waiting for the client's initial response */
      66              :     OAUTH_STATE_ERROR,              /* validation failed; error response sent, waiting for the client's kvsep */
      67              :     OAUTH_STATE_ERROR_DISCOVERY,    /* empty auth value (discovery request); error response sent, waiting for kvsep */
      68              :     OAUTH_STATE_FINISHED,           /* exchange is over; oauth_exchange() rejects further messages as an invalid state */
      69              : };
      70              : 
      71              : /* Mechanism callback state, allocated by oauth_init() and passed to oauth_exchange() as its opaque argument. */
      72              : struct oauth_ctx
      73              : {
      74              :     enum oauth_state state;     /* current position in the exchange */
      75              :     Port       *port;           /* the Port given to oauth_init() */
      76              :     const char *issuer;         /* taken from port->hba->oauth_issuer */
      77              :     const char *scope;          /* taken from port->hba->oauth_scope */
      78              : };
      79              : 
      80              : static char *sanitize_char(char c);     /* forward declarations of helpers called by oauth_exchange() */
      81              : static char *parse_kvpairs_for_auth(char **input);
      82              : static void generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen);
      83              : static bool validate(Port *port, const char *auth, const char **logdetail);
      84              : 
      85              : /* Constants seen in an OAUTHBEARER client initial response. */
      86              : #define KVSEP 0x01              /* separator byte for key/value pairs */
      87              : #define AUTH_KEY "auth"           /* key containing the Authorization header */
      88              : #define BEARER_SCHEME "Bearer " /* required header scheme (case-insensitive!) */
      89              : 
      90              : /*
      91              :  * Retrieves the OAUTHBEARER mechanism list (currently a single item).
      92              :  * The mechanism name is appended to buf followed by a NUL byte; port is not used.
      93              :  * For a full description of the API, see libpq/sasl.h.
      94              :  */
      95              : static void
      96            0 : oauth_get_mechanisms(Port *port, StringInfo buf)
      97              : {
      98              :     /* Only OAUTHBEARER is supported. */
      99            0 :     appendStringInfoString(buf, OAUTHBEARER_NAME);
     100            0 :     appendStringInfoChar(buf, '\0');    /* explicit NUL after the mechanism name */
     101            0 : }
     102              : 
     103              : /*
     104              :  * Initializes mechanism state and loads the configured validator module.
     105              :  * Raises ERROR unless selected_mech is OAUTHBEARER_NAME; shadow_pass is not used. Returns the new struct oauth_ctx.
     106              :  * For a full description of the API, see libpq/sasl.h.
     107              :  */
     108              : static void *
     109            0 : oauth_init(Port *port, const char *selected_mech, const char *shadow_pass)
     110              : {
     111              :     struct oauth_ctx *ctx;
     112              : 
     113            0 :     if (strcmp(selected_mech, OAUTHBEARER_NAME) != 0)
     114            0 :         ereport(ERROR,
     115              :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     116              :                 errmsg("client selected an invalid SASL authentication mechanism"));
     117              : 
     118              :     /* Save our memory context for later use by client API calls. */
     119            0 :     ValidatorMemoryContext = CurrentMemoryContext;
     120              : 
     121            0 :     ctx = palloc0_object(struct oauth_ctx);
     122              : 
     123            0 :     ctx->state = OAUTH_STATE_INIT;
     124            0 :     ctx->port = port;
     125              : 
     126              :     Assert(port->hba);
     127            0 :     ctx->issuer = port->hba->oauth_issuer;  /* pointers are stored as-is, not copied */
     128            0 :     ctx->scope = port->hba->oauth_scope;
     129              : 
     130            0 :     load_validator_library(port->hba->oauth_validator);
     131              : 
     132            0 :     return ctx;
     133              : }
     134              : 
     135              : /*
     136              :  * Implements the OAUTHBEARER SASL exchange (RFC 7628, Sec. 3.2). This pulls
     137              :  * apart the client initial response and validates the Bearer token. It also
     138              :  * handles the dummy error response for a failed handshake, as described in
     139              :  * Sec. 3.2.3.
     140              :  *
     141              :  * Returns a PG_SASL_EXCHANGE_* status; malformed messages raise ERROR instead. For a full description of the API, see libpq/sasl.h.
     142              :  */
     143              : static int
     144            0 : oauth_exchange(void *opaq, const char *input, int inputlen,
     145              :                char **output, int *outputlen, const char **logdetail)
     146              : {
     147              :     char       *input_copy;
     148              :     char       *p;
     149              :     char        cbind_flag;
     150              :     char       *auth;
     151              :     int         status;
     152              : 
     153            0 :     struct oauth_ctx *ctx = opaq;
     154              : 
     155            0 :     *output = NULL;             /* defaults: no output unless a branch below sets one */
     156            0 :     *outputlen = -1;
     157              : 
     158              :     /*
     159              :      * If the client didn't include an "Initial Client Response" in the
     160              :      * SASLInitialResponse message, send an empty challenge, to which the
     161              :      * client will respond with the same data that usually comes in the
     162              :      * Initial Client Response.
     163              :      */
     164            0 :     if (input == NULL)
     165              :     {
     166              :         Assert(ctx->state == OAUTH_STATE_INIT);
     167              : 
     168            0 :         *output = pstrdup("");
     169            0 :         *outputlen = 0;
     170            0 :         return PG_SASL_EXCHANGE_CONTINUE;
     171              :     }
     172              : 
     173              :     /*
     174              :      * Check that the input length agrees with the string length of the input.
     175              :      */
     176            0 :     if (inputlen == 0)
     177            0 :         ereport(ERROR,
     178              :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     179              :                 errmsg("malformed OAUTHBEARER message"),
     180              :                 errdetail("The message is empty."));
     181            0 :     if (inputlen != strlen(input))  /* e.g. an embedded NUL byte would make these differ */
     182            0 :         ereport(ERROR,
     183              :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     184              :                 errmsg("malformed OAUTHBEARER message"),
     185              :                 errdetail("Message length does not match input length."));
     186              : 
     187            0 :     switch (ctx->state)
     188              :     {
     189            0 :         case OAUTH_STATE_INIT:
     190              :             /* Handle this case below. */
     191            0 :             break;
     192              : 
     193            0 :         case OAUTH_STATE_ERROR:
     194              :         case OAUTH_STATE_ERROR_DISCOVERY:
     195              : 
     196              :             /*
     197              :              * Only one response is valid for the client during authentication
     198              :              * failure: a single kvsep.
     199              :              */
     200            0 :             if (inputlen != 1 || *input != KVSEP)
     201            0 :                 ereport(ERROR,
     202              :                         errcode(ERRCODE_PROTOCOL_VIOLATION),
     203              :                         errmsg("malformed OAUTHBEARER message"),
     204              :                         errdetail("Client did not send a kvsep response."));
     205              : 
     206              :             /*
     207              :              * The (failed) handshake is now complete. Don't report discovery
     208              :              * requests in the server log unless the log level is high enough.
     209              :              */
     210            0 :             if (ctx->state == OAUTH_STATE_ERROR_DISCOVERY)
     211              :             {
     212            0 :                 ereport(DEBUG1, errmsg("OAuth issuer discovery requested"));
     213              : 
     214            0 :                 ctx->state = OAUTH_STATE_FINISHED;
     215            0 :                 return PG_SASL_EXCHANGE_ABANDONED;
     216              :             }
     217              : 
     218              :             /* We're not in discovery, so this is just a normal auth failure. */
     219            0 :             ctx->state = OAUTH_STATE_FINISHED;
     220            0 :             return PG_SASL_EXCHANGE_FAILURE;
     221              : 
     222            0 :         default:                /* includes OAUTH_STATE_FINISHED: no message is accepted after the exchange ends */
     223            0 :             elog(ERROR, "invalid OAUTHBEARER exchange state");
     224              :             return PG_SASL_EXCHANGE_FAILURE;
     225              :     }
     226              : 
     227              :     /* Handle the client's initial message. */
     228            0 :     p = input_copy = pstrdup(input);    /* writable copy: parse_kvpairs_for_auth() modifies it in place */
     229              : 
     230              :     /*
     231              :      * OAUTHBEARER does not currently define a channel binding (so there is no
     232              :      * OAUTHBEARER-PLUS, and we do not accept a 'p' specifier). We accept a
     233              :      * 'y' specifier purely for the remote chance that a future specification
     234              :      * could define one; then future clients can still interoperate with this
     235              :      * server implementation. 'n' is the expected case.
     236              :      */
     237            0 :     cbind_flag = *p;
     238            0 :     switch (cbind_flag)
     239              :     {
     240            0 :         case 'p':
     241            0 :             ereport(ERROR,
     242              :                     errcode(ERRCODE_PROTOCOL_VIOLATION),
     243              :                     errmsg("malformed OAUTHBEARER message"),
     244              :                     errdetail("The server does not support channel binding for OAuth, but the client message includes channel binding data."));
     245              :             break;
     246              : 
     247            0 :         case 'y':               /* fall through */
     248              :         case 'n':
     249            0 :             p++;
     250            0 :             if (*p != ',')
     251            0 :                 ereport(ERROR,
     252              :                         errcode(ERRCODE_PROTOCOL_VIOLATION),
     253              :                         errmsg("malformed OAUTHBEARER message"),
     254              :                         errdetail("Comma expected, but found character \"%s\".",
     255              :                                   sanitize_char(*p)));
     256            0 :             p++;
     257            0 :             break;
     258              : 
     259            0 :         default:
     260            0 :             ereport(ERROR,
     261              :                     errcode(ERRCODE_PROTOCOL_VIOLATION),
     262              :                     errmsg("malformed OAUTHBEARER message"),
     263              :                     errdetail("Unexpected channel-binding flag \"%s\".",
     264              :                               sanitize_char(cbind_flag)));
     265              :     }
     266              : 
     267              :     /*
     268              :      * Forbid optional authzid (authorization identity).  We don't support it.
     269              :      */
     270            0 :     if (*p == 'a')
     271            0 :         ereport(ERROR,
     272              :                 errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
     273              :                 errmsg("client uses authorization identity, but it is not supported"));
     274            0 :     if (*p != ',')
     275            0 :         ereport(ERROR,
     276              :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     277              :                 errmsg("malformed OAUTHBEARER message"),
     278              :                 errdetail("Unexpected attribute \"%s\" in client-first-message.",
     279              :                           sanitize_char(*p)));
     280            0 :     p++;
     281              : 
     282              :     /* All remaining fields are separated by the RFC's kvsep (\x01). */
     283            0 :     if (*p != KVSEP)
     284            0 :         ereport(ERROR,
     285              :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     286              :                 errmsg("malformed OAUTHBEARER message"),
     287              :                 errdetail("Key-value separator expected, but found character \"%s\".",
     288              :                           sanitize_char(*p)));
     289            0 :     p++;
     290              : 
     291            0 :     auth = parse_kvpairs_for_auth(&p);  /* auth points into input_copy; p is advanced past the final terminator */
     292            0 :     if (!auth)
     293            0 :         ereport(ERROR,
     294              :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     295              :                 errmsg("malformed OAUTHBEARER message"),
     296              :                 errdetail("Message does not contain an auth value."));
     297              : 
     298              :     /* We should be at the end of our message. */
     299            0 :     if (*p)
     300            0 :         ereport(ERROR,
     301              :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     302              :                 errmsg("malformed OAUTHBEARER message"),
     303              :                 errdetail("Message contains additional data after the final terminator."));
     304              : 
     305              :     /*
     306              :      * Make sure all custom HBA options are understood by the validator before
     307              :      * continuing, since we couldn't check them during server start/reload.
     308              :      */
     309            0 :     if (!check_validator_hba_options(ctx->port, logdetail))
     310              :     {
     311            0 :         ctx->state = OAUTH_STATE_FINISHED;
     312            0 :         return PG_SASL_EXCHANGE_FAILURE;    /* NOTE(review): this return skips the explicit_bzero(input_copy, ...) done on the normal exit path -- confirm that is intended */
     313              :     }
     314              : 
     315            0 :     if (auth[0] == '\0')
     316              :     {
     317              :         /*
     318              :          * An empty auth value represents a discovery request; the client
     319              :          * expects it to fail.  Skip validation entirely and move directly to
     320              :          * the error response.
     321              :          */
     322            0 :         generate_error_response(ctx, output, outputlen);
     323              : 
     324            0 :         ctx->state = OAUTH_STATE_ERROR_DISCOVERY;
     325            0 :         status = PG_SASL_EXCHANGE_CONTINUE;
     326              :     }
     327            0 :     else if (!validate(ctx->port, auth, logdetail))
     328              :     {
     329            0 :         generate_error_response(ctx, output, outputlen);
     330              : 
     331            0 :         ctx->state = OAUTH_STATE_ERROR;
     332            0 :         status = PG_SASL_EXCHANGE_CONTINUE;     /* the failure itself is returned on the next call, after the client's kvsep */
     333              :     }
     334              :     else
     335              :     {
     336            0 :         ctx->state = OAUTH_STATE_FINISHED;
     337            0 :         status = PG_SASL_EXCHANGE_SUCCESS;
     338              :     }
     339              : 
     340              :     /* Don't let extra copies of the bearer token hang around. */
     341            0 :     explicit_bzero(input_copy, inputlen);   /* auth points into this buffer and must not be used afterwards */
     342              : 
     343            0 :     return status;
     344              : }
     345              : 
     346              : /*
     347              :  * Convert an arbitrary byte to printable form.  For error messages.
     348              :  *
     349              :  * If it's a printable ASCII character (0x21-0x7E), print it as a single
     350              :  * character in single quotes; otherwise, print it in hex (0xNN).
     351              :  *
     352              :  * The returned pointer points to a static buffer, which each call overwrites.
     353              :  */
     354              : static char *
     355            0 : sanitize_char(char c)
     356              : {
     357              :     static char buf[5];         /* holds 'c' or 0xNN plus the terminating NUL */
     358              : 
     359            0 :     if (c >= 0x21 && c <= 0x7E)
     360            0 :         snprintf(buf, sizeof(buf), "'%c'", c);
     361              :     else
     362            0 :         snprintf(buf, sizeof(buf), "0x%02x", (unsigned char) c);    /* cast keeps the value to two hex digits */
     363            0 :     return buf;
     364              : }
     365              : 
     366              : /*
     367              :  * Performs syntactic validation of a key and value from the initial client
     368              :  * response. (Semantic validation of interesting values must be performed
     369              :  * later.) Any violation raises ERROR with ERRCODE_PROTOCOL_VIOLATION.
     370              :  */
     371              : static void
     372            0 : validate_kvpair(const char *key, const char *val)
     373              : {
     374              :     /*-----
     375              :      * From Sec 3.1:
     376              :      *     key            = 1*(ALPHA)
     377              :      */
     378              :     static const char *key_allowed_set =
     379              :         "abcdefghijklmnopqrstuvwxyz"
     380              :         "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
     381              : 
     382              :     size_t      span;
     383              : 
     384            0 :     if (!key[0])
     385            0 :         ereport(ERROR,
     386              :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     387              :                 errmsg("malformed OAUTHBEARER message"),
     388              :                 errdetail("Message contains an empty key name."));
     389              : 
     390            0 :     span = strspn(key, key_allowed_set);    /* length of the leading run of ALPHA bytes */
     391            0 :     if (key[span] != '\0')                  /* the run must reach the end of the key */
     392            0 :         ereport(ERROR,
     393              :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     394              :                 errmsg("malformed OAUTHBEARER message"),
     395              :                 errdetail("Message contains an invalid key name."));
     396              : 
     397              :     /*-----
     398              :      * From Sec 3.1:
     399              :      *     value          = *(VCHAR / SP / HTAB / CR / LF )
     400              :      *
     401              :      * The VCHAR (visible character) class is large; a loop is more
     402              :      * straightforward than strspn().
     403              :      */
     404            0 :     for (; *val; ++val)
     405              :     {
     406            0 :         if (0x21 <= *val && *val <= 0x7E)
     407            0 :             continue;           /* VCHAR */
     408              : 
     409            0 :         switch (*val)
     410              :         {
     411            0 :             case ' ':
     412              :             case '\t':
     413              :             case '\r':
     414              :             case '\n':
     415            0 :                 continue;       /* SP, HTAB, CR, LF */
     416              : 
     417            0 :             default:            /* any other byte is outside the value grammar */
     418            0 :                 ereport(ERROR,
     419              :                         errcode(ERRCODE_PROTOCOL_VIOLATION),
     420              :                         errmsg("malformed OAUTHBEARER message"),
     421              :                         errdetail("Message contains an invalid value."));
     422              :         }
     423              :     }
     424            0 : }
     425              : 
     426              : /*
     427              :  * Consumes all kvpairs in an OAUTHBEARER exchange message. If the "auth" key is
     428              :  * found, its value is returned (a pointer into the caller's buffer, which is modified in place); otherwise NULL. On return, *input points just past the final terminator. Malformed input raises ERROR.
     429              :  */
     430              : static char *
     431            0 : parse_kvpairs_for_auth(char **input)
     432              : {
     433            0 :     char       *pos = *input;
     434            0 :     char       *auth = NULL;
     435              : 
     436              :     /*----
     437              :      * The relevant ABNF, from Sec. 3.1:
     438              :      *
     439              :      *     kvsep          = %x01
     440              :      *     key            = 1*(ALPHA)
     441              :      *     value          = *(VCHAR / SP / HTAB / CR / LF )
     442              :      *     kvpair         = key "=" value kvsep
     443              :      *   ;;gs2-header     = See RFC 5801
     444              :      *     client-resp    = (gs2-header kvsep *kvpair kvsep) / kvsep
     445              :      *
     446              :      * By the time we reach this code, the gs2-header and initial kvsep have
     447              :      * already been validated. We start at the beginning of the first kvpair.
     448              :      */
     449              : 
     450            0 :     while (*pos)
     451              :     {
     452              :         char       *end;
     453              :         char       *sep;
     454              :         char       *key;
     455              :         char       *value;
     456              : 
     457              :         /*
     458              :          * Find the end of this kvpair. Note that input is null-terminated by
     459              :          * the SASL code, so the strchr() is bounded.
     460              :          */
     461            0 :         end = strchr(pos, KVSEP);
     462            0 :         if (!end)
     463            0 :             ereport(ERROR,
     464              :                     errcode(ERRCODE_PROTOCOL_VIOLATION),
     465              :                     errmsg("malformed OAUTHBEARER message"),
     466              :                     errdetail("Message contains an unterminated key/value pair."));
     467            0 :         *end = '\0';            /* overwrite the kvsep so this pair is a C string */
     468              : 
     469            0 :         if (pos == end)
     470              :         {
     471              :             /* Empty kvpair, signifying the end of the list. */
     472            0 :             *input = pos + 1;   /* step past the final terminator */
     473            0 :             return auth;        /* still NULL if no "auth" key was seen */
     474              :         }
     475              : 
     476              :         /*
     477              :          * Find the end of the key name.
     478              :          */
     479            0 :         sep = strchr(pos, '=');
     480            0 :         if (!sep)
     481            0 :             ereport(ERROR,
     482              :                     errcode(ERRCODE_PROTOCOL_VIOLATION),
     483              :                     errmsg("malformed OAUTHBEARER message"),
     484              :                     errdetail("Message contains a key without a value."));
     485            0 :         *sep = '\0';            /* split at the first '=': key before it, value after it */
     486              : 
     487              :         /* Both key and value are now safely terminated. */
     488            0 :         key = pos;
     489            0 :         value = sep + 1;
     490            0 :         validate_kvpair(key, value);
     491              : 
     492            0 :         if (strcmp(key, AUTH_KEY) == 0)     /* case-sensitive match on the key name */
     493              :         {
     494            0 :             if (auth)
     495            0 :                 ereport(ERROR,
     496              :                         errcode(ERRCODE_PROTOCOL_VIOLATION),
     497              :                         errmsg("malformed OAUTHBEARER message"),
     498              :                         errdetail("Message contains multiple auth values."));
     499              : 
     500            0 :             auth = value;
     501              :         }
     502              :         else
     503              :         {
     504              :             /*
     505              :              * The RFC also defines the host and port keys, but they are not
     506              :              * required for OAUTHBEARER and we do not use them. Also, per Sec.
     507              :              * 3.1, any key/value pairs we don't recognize must be ignored.
     508              :              */
     509              :         }
     510              : 
     511              :         /* Move to the next pair. */
     512            0 :         pos = end + 1;
     513              :     }
     514              : 
     515            0 :     ereport(ERROR,              /* the loop ran out of input before seeing the empty terminating pair */
     516              :             errcode(ERRCODE_PROTOCOL_VIOLATION),
     517              :             errmsg("malformed OAUTHBEARER message"),
     518              :             errdetail("Message did not contain a final terminator."));
     519              : 
     520              :     pg_unreachable();
     521              :     return NULL;
     522              : }
     523              : 
     524              : /*
     525              :  * Builds the JSON response for failed authentication (RFC 7628, Sec. 3.2.2).
     526              :  * This contains the required scopes for entry and a pointer to the OAuth/OpenID
     527              :  * discovery document, which the client may use to conduct its OAuth flow. The JSON text is returned in *output and its length in *outputlen.
     528              :  */
     529              : static void
     530            0 : generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen)
     531              : {
     532              :     StringInfoData buf;
     533              :     StringInfoData issuer;
     534              : 
     535              :     /*
     536              :      * The admin needs to set an issuer and scope for OAuth to work. There's
     537              :      * not really a way to hide this from the user, either, because we can't
     538              :      * choose a "default" issuer, so be honest in the failure message. (In
     539              :      * practice such configurations are rejected during HBA parsing.)
     540              :      */
     541            0 :     if (!ctx->issuer || !ctx->scope)
     542            0 :         ereport(FATAL,
     543              :                 errcode(ERRCODE_INTERNAL_ERROR),
     544              :                 errmsg("OAuth is not properly configured for this user"),
     545              :                 errdetail_log("The issuer and scope parameters must be set in pg_hba.conf."));
     546              : 
     547              :     /*
     548              :      * Build a default .well-known URI based on our issuer, unless the HBA has
     549              :      * already provided one.
     550              :      */
     551            0 :     initStringInfo(&issuer);
     552            0 :     appendStringInfoString(&issuer, ctx->issuer);
     553            0 :     if (strstr(ctx->issuer, "/.well-known/") == NULL)   /* no "/.well-known/" anywhere in the issuer string */
     554            0 :         appendStringInfoString(&issuer, "/.well-known/openid-configuration");
     555              : 
     556            0 :     initStringInfo(&buf);
     557              : 
     558              :     /*
     559              :      * Escaping the string here is belt-and-suspenders defensive programming
     560              :      * since escapable characters aren't valid in either the issuer URI or the
     561              :      * scope list, but the HBA doesn't enforce that yet.
     562              :      */
     563            0 :     appendStringInfoString(&buf, "{ \"status\": \"invalid_token\", ");
     564              : 
     565            0 :     appendStringInfoString(&buf, "\"openid-configuration\": ");
     566            0 :     escape_json(&buf, issuer.data);
     567            0 :     pfree(issuer.data);         /* the temporary URI is no longer needed once appended to buf */
     568              : 
     569            0 :     appendStringInfoString(&buf, ", \"scope\": ");
     570            0 :     escape_json(&buf, ctx->scope);
     571              : 
     572            0 :     appendStringInfoString(&buf, " }");
     573              : 
     574            0 :     *output = buf.data;         /* buf's storage is handed to the caller, not freed here */
     575            0 :     *outputlen = buf.len;
     576            0 : }
     577              : 
     578              : /*-----
     579              :  * Validates the provided Authorization header and returns the token from
     580              :  * within it. NULL is returned on validation failure.
     581              :  *
     582              :  * Only Bearer tokens are accepted. The ABNF is defined in RFC 6750, Sec.
     583              :  * 2.1:
     584              :  *
     585              :  *      b64token    = 1*( ALPHA / DIGIT /
     586              :  *                        "-" / "." / "_" / "~" / "+" / "/" ) *"="
     587              :  *      credentials = "Bearer" 1*SP b64token
     588              :  *
     589              :  * The "credentials" construction is what we receive in our auth value.
     590              :  *
     591              :  * Since that spec is subordinate to HTTP (i.e. the HTTP Authorization
     592              :  * header format; RFC 9110 Sec. 11), the "Bearer" scheme string must be
     593              :  * compared case-insensitively. (This is not mentioned in RFC 6750, but the
     594              :  * OAUTHBEARER spec points it out: RFC 7628 Sec. 4.)
     595              :  *
     596              :  * Invalid formats are technically a protocol violation, but we shouldn't
     597              :  * reflect any information about the sensitive Bearer token back to the
     598              :  * client; log at COMMERROR instead.
     599              :  */
     600              : static const char *
     601            0 : validate_token_format(const char *header)
     602              : {
     603              :     size_t      span;
     604              :     const char *token;
     605              :     static const char *const b64token_allowed_set =
     606              :         "abcdefghijklmnopqrstuvwxyz"
     607              :         "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
     608              :         "0123456789-._~+/";
     609              : 
     610              :     /* Missing auth headers should be handled by the caller. */
     611              :     Assert(header);
     612              :     /* Empty auth (discovery) should be handled before calling validate(). */
     613              :     Assert(header[0] != '\0');
     614              : 
     615            0 :     if (pg_strncasecmp(header, BEARER_SCHEME, strlen(BEARER_SCHEME)))
     616              :     {
     617            0 :         ereport(COMMERROR,
     618              :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     619              :                 errmsg("malformed OAuth bearer token"),
     620              :                 errdetail_log("Client response indicated a non-Bearer authentication scheme."));
     621            0 :         return NULL;
     622              :     }
     623              : 
     624              :     /* Pull the bearer token out of the auth value. */
     625            0 :     token = header + strlen(BEARER_SCHEME);
     626              : 
     627              :     /* Swallow any additional spaces. */
     628            0 :     while (*token == ' ')
     629            0 :         token++;
     630              : 
     631              :     /* Tokens must not be empty. */
     632            0 :     if (!*token)
     633              :     {
     634            0 :         ereport(COMMERROR,
     635              :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     636              :                 errmsg("malformed OAuth bearer token"),
     637              :                 errdetail_log("Bearer token is empty."));
     638            0 :         return NULL;
     639              :     }
     640              : 
     641              :     /*
     642              :      * Make sure the token contains only allowed characters. Tokens may end
     643              :      * with any number of '=' characters.
     644              :      */
     645            0 :     span = strspn(token, b64token_allowed_set);
     646            0 :     while (token[span] == '=')
     647            0 :         span++;
     648              : 
     649            0 :     if (token[span] != '\0')
     650              :     {
     651              :         /*
     652              :          * This error message could be more helpful by printing the
     653              :          * problematic character(s), but that'd be a bit like printing a piece
     654              :          * of someone's password into the logs.
     655              :          */
     656            0 :         ereport(COMMERROR,
     657              :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     658              :                 errmsg("malformed OAuth bearer token"),
     659              :                 errdetail_log("Bearer token is not in the correct format."));
     660            0 :         return NULL;
     661              :     }
     662              : 
     663            0 :     return token;
     664              : }
     665              : 
     666              : /*
     667              :  * Checks that the "auth" kvpair in the client response contains a syntactically
     668              :  * valid Bearer token, then passes it along to the loaded validator module for
     669              :  * authorization. Returns true if validation succeeds.
     670              :  */
     671              : static bool
     672            0 : validate(Port *port, const char *auth, const char **logdetail)
     673              : {
     674              :     int         map_status;
     675              :     ValidatorModuleResult *ret;
     676              :     const char *token;
     677              :     bool        status;
     678              : 
     679              :     /* Ensure that we have a correct token to validate */
     680            0 :     if (!(token = validate_token_format(auth)))
     681            0 :         return false;
     682              : 
     683              :     /*
     684              :      * Ensure that we have a validation library loaded, this should always be
     685              :      * the case and an error here is indicative of a bug.
     686              :      */
     687            0 :     if (!ValidatorCallbacks || !ValidatorCallbacks->validate_cb)
     688            0 :         ereport(FATAL,
     689              :                 errcode(ERRCODE_INTERNAL_ERROR),
     690              :                 errmsg("validation of OAuth token requested without a validator loaded"));
     691              : 
     692              :     /* Call the validation function from the validator module */
     693            0 :     ret = palloc0_object(ValidatorModuleResult);
     694            0 :     if (!ValidatorCallbacks->validate_cb(validator_module_state, token,
     695            0 :                                          port->user_name, ret))
     696              :     {
     697            0 :         ereport(WARNING,
     698              :                 errcode(ERRCODE_INTERNAL_ERROR),
     699              :                 errmsg("internal error in OAuth validator module"),
     700              :                 ret->error_detail ? errdetail_log("%s", ret->error_detail) : 0);
     701              : 
     702            0 :         *logdetail = ret->error_detail;
     703            0 :         return false;
     704              :     }
     705              : 
     706              :     /*
     707              :      * Log any authentication results even if the token isn't authorized; it
     708              :      * might be useful for auditing or troubleshooting.
     709              :      */
     710            0 :     if (ret->authn_id)
     711            0 :         set_authn_id(port, ret->authn_id);
     712              : 
     713            0 :     if (!ret->authorized)
     714              :     {
     715            0 :         if (ret->error_detail)
     716            0 :             *logdetail = ret->error_detail;
     717              :         else
     718            0 :             *logdetail = _("Validator failed to authorize the provided token.");
     719              : 
     720            0 :         status = false;
     721            0 :         goto cleanup;
     722              :     }
     723              : 
     724            0 :     if (port->hba->oauth_skip_usermap)
     725              :     {
     726              :         /*
     727              :          * If the validator is our authorization authority, we're done.
     728              :          * Authentication may or may not have been performed depending on the
     729              :          * validator implementation; all that matters is that the validator
     730              :          * says the user can log in with the target role.
     731              :          */
     732            0 :         status = true;
     733            0 :         goto cleanup;
     734              :     }
     735              : 
     736              :     /* Make sure the validator authenticated the user. */
     737            0 :     if (ret->authn_id == NULL || ret->authn_id[0] == '\0')
     738              :     {
     739            0 :         *logdetail = _("Validator provided no identity.");
     740              : 
     741            0 :         status = false;
     742            0 :         goto cleanup;
     743              :     }
     744              : 
     745              :     /* Finally, check the user map. */
     746            0 :     map_status = check_usermap(port->hba->usermap, port->user_name,
     747              :                                MyClientConnectionInfo.authn_id, false);
     748            0 :     status = (map_status == STATUS_OK);
     749              : 
     750            0 : cleanup:
     751              : 
     752              :     /*
     753              :      * Clear and free the validation result from the validator module once
     754              :      * we're done with it.
     755              :      */
     756            0 :     if (ret->authn_id != NULL)
     757            0 :         pfree(ret->authn_id);
     758            0 :     pfree(ret);
     759              : 
     760            0 :     return status;
     761              : }
     762              : 
     763              : /*
     764              :  * load_validator_library
     765              :  *
     766              :  * Load the configured validator library in order to perform token validation.
     767              :  * There is no built-in fallback since validation is implementation specific. If
     768              :  * no validator library is configured, or if it fails to load, then error out
     769              :  * since token validation won't be possible.
     770              :  */
     771              : static void
     772            0 : load_validator_library(const char *libname)
     773              : {
     774              :     OAuthValidatorModuleInit validator_init;
     775              :     MemoryContextCallback *mcb;
     776              : 
     777              :     /*
     778              :      * The presence, and validity, of libname has already been established by
     779              :      * check_oauth_validator so we don't need to perform more than Assert
     780              :      * level checking here.
     781              :      */
     782              :     Assert(libname && *libname);
     783              : 
     784            0 :     validator_init = (OAuthValidatorModuleInit)
     785            0 :         load_external_function(libname, "_PG_oauth_validator_module_init",
     786              :                                false, NULL);
     787              : 
     788              :     /*
     789              :      * The validator init function is required since it will set the callbacks
     790              :      * for the validator library.
     791              :      */
     792            0 :     if (validator_init == NULL)
     793            0 :         ereport(ERROR,
     794              :                 errmsg("%s module \"%s\" must define the symbol %s",
     795              :                        "OAuth validator", libname, "_PG_oauth_validator_module_init"));
     796              : 
     797            0 :     ValidatorCallbacks = (*validator_init) ();
     798              :     Assert(ValidatorCallbacks);
     799              : 
     800              :     /*
     801              :      * Check the magic number, to protect against break-glass scenarios where
     802              :      * the ABI must change within a major version. load_external_function()
     803              :      * already checks for compatibility across major versions.
     804              :      */
     805            0 :     if (ValidatorCallbacks->magic != PG_OAUTH_VALIDATOR_MAGIC)
     806            0 :         ereport(ERROR,
     807              :                 errmsg("%s module \"%s\": magic number mismatch",
     808              :                        "OAuth validator", libname),
     809              :                 errdetail("Server has magic number 0x%08X, module has 0x%08X.",
     810              :                           PG_OAUTH_VALIDATOR_MAGIC, ValidatorCallbacks->magic));
     811              : 
     812              :     /*
     813              :      * Make sure all required callbacks are present in the ValidatorCallbacks
     814              :      * structure. Right now only the validation callback is required.
     815              :      */
     816            0 :     if (ValidatorCallbacks->validate_cb == NULL)
     817            0 :         ereport(ERROR,
     818              :                 errmsg("%s module \"%s\" must provide a %s callback",
     819              :                        "OAuth validator", libname, "validate_cb"));
     820              : 
     821              :     /* Allocate memory for validator library private state data */
     822            0 :     validator_module_state = palloc0_object(ValidatorModuleState);
     823            0 :     validator_module_state->sversion = PG_VERSION_NUM;
     824              : 
     825            0 :     if (ValidatorCallbacks->startup_cb != NULL)
     826            0 :         ValidatorCallbacks->startup_cb(validator_module_state);
     827              : 
     828              :     /* Shut down the library before cleaning up its state. */
     829            0 :     mcb = palloc0_object(MemoryContextCallback);
     830            0 :     mcb->func = shutdown_validator_library;
     831              : 
     832            0 :     MemoryContextRegisterResetCallback(CurrentMemoryContext, mcb);
     833            0 : }
     834              : 
     835              : /*
     836              :  * Call the validator module's shutdown callback, if one is provided. This is
     837              :  * invoked during memory context reset.
     838              :  */
     839              : static void
     840            0 : shutdown_validator_library(void *arg)
     841              : {
     842            0 :     if (ValidatorCallbacks->shutdown_cb != NULL)
     843            0 :         ValidatorCallbacks->shutdown_cb(validator_module_state);
     844              : 
     845              :     /* The backing memory for this is about to disappear. */
     846            0 :     ValidatorOptions = NIL;
     847            0 : }
     848              : 
     849              : /*
     850              :  * Ensure an OAuth validator named in the HBA is permitted by the configuration.
     851              :  *
     852              :  * If the validator is currently unset and exactly one library is declared in
     853              :  * oauth_validator_libraries, then that library will be used as the validator.
     854              :  * Otherwise the name must be present in the list of oauth_validator_libraries.
     855              :  */
     856              : bool
     857            0 : check_oauth_validator(HbaLine *hbaline, int elevel, char **err_msg)
     858              : {
     859            0 :     int         line_num = hbaline->linenumber;
     860            0 :     const char *file_name = hbaline->sourcefile;
     861              :     char       *rawstring;
     862            0 :     List       *elemlist = NIL;
     863              : 
     864            0 :     *err_msg = NULL;
     865              : 
     866            0 :     if (oauth_validator_libraries_string[0] == '\0')
     867              :     {
     868            0 :         ereport(elevel,
     869              :                 errcode(ERRCODE_CONFIG_FILE_ERROR),
     870              :                 errmsg("oauth_validator_libraries must be set for authentication method %s",
     871              :                        "oauth"),
     872              :                 errcontext("line %d of configuration file \"%s\"",
     873              :                            line_num, file_name));
     874            0 :         *err_msg = psprintf("oauth_validator_libraries must be set for authentication method %s",
     875              :                             "oauth");
     876            0 :         return false;
     877              :     }
     878              : 
     879              :     /* SplitDirectoriesString needs a modifiable copy */
     880            0 :     rawstring = pstrdup(oauth_validator_libraries_string);
     881              : 
     882            0 :     if (!SplitDirectoriesString(rawstring, ',', &elemlist))
     883              :     {
     884              :         /* syntax error in list */
     885            0 :         ereport(elevel,
     886              :                 errcode(ERRCODE_CONFIG_FILE_ERROR),
     887              :                 errmsg("invalid list syntax in parameter \"%s\"",
     888              :                        "oauth_validator_libraries"));
     889            0 :         *err_msg = psprintf("invalid list syntax in parameter \"%s\"",
     890              :                             "oauth_validator_libraries");
     891            0 :         goto done;
     892              :     }
     893              : 
     894            0 :     if (!hbaline->oauth_validator)
     895              :     {
     896            0 :         if (elemlist->length == 1)
     897              :         {
     898            0 :             hbaline->oauth_validator = pstrdup(linitial(elemlist));
     899            0 :             goto done;
     900              :         }
     901              : 
     902            0 :         ereport(elevel,
     903              :                 errcode(ERRCODE_CONFIG_FILE_ERROR),
     904              :                 errmsg("authentication method \"oauth\" requires argument \"validator\" to be set when oauth_validator_libraries contains multiple options"),
     905              :                 errcontext("line %d of configuration file \"%s\"",
     906              :                            line_num, file_name));
     907            0 :         *err_msg = "authentication method \"oauth\" requires argument \"validator\" to be set when oauth_validator_libraries contains multiple options";
     908            0 :         goto done;
     909              :     }
     910              : 
     911            0 :     foreach_ptr(char, allowed, elemlist)
     912              :     {
     913            0 :         if (strcmp(allowed, hbaline->oauth_validator) == 0)
     914            0 :             goto done;
     915              :     }
     916              : 
     917            0 :     ereport(elevel,
     918              :             errcode(ERRCODE_INVALID_PARAMETER_VALUE),
     919              :             errmsg("validator \"%s\" is not permitted by %s",
     920              :                    hbaline->oauth_validator, "oauth_validator_libraries"),
     921              :             errcontext("line %d of configuration file \"%s\"",
     922              :                        line_num, file_name));
     923            0 :     *err_msg = psprintf("validator \"%s\" is not permitted by %s",
     924              :                         hbaline->oauth_validator, "oauth_validator_libraries");
     925              : 
     926            0 : done:
     927            0 :     list_free_deep(elemlist);
     928            0 :     pfree(rawstring);
     929              : 
     930            0 :     return (*err_msg == NULL);
     931              : }
     932              : 
     933              : /*
     934              :  * Client APIs for validator implementations
     935              :  *
     936              :  * Since we're currently not threaded, we only allow one validator in the
     937              :  * process at a time. So we can make use of globals for now instead of looking
     938              :  * up information using the state pointer. We probably shouldn't assume that the
     939              :  * module hasn't temporarily changed memory contexts on us, though; functions
     940              :  * here should defensively use an appropriate context when making global
     941              :  * allocations.
     942              :  */
     943              : 
     944              : /*
     945              :  * Adds to the list of allowed validator.* HBA options. Used during the
     946              :  * startup_cb.
     947              :  */
     948              : void
     949            0 : RegisterOAuthHBAOptions(ValidatorModuleState *state, int num,
     950              :                         const char *opts[])
     951              : {
     952              :     MemoryContext oldcontext;
     953              : 
     954            0 :     if (!state)
     955              :     {
     956              :         Assert(false);
     957            0 :         return;
     958              :     }
     959              : 
     960            0 :     oldcontext = MemoryContextSwitchTo(ValidatorMemoryContext);
     961              : 
     962            0 :     for (int i = 0; i < num; i++)
     963              :     {
     964            0 :         if (!valid_oauth_hba_option_name(opts[i]))
     965              :         {
     966              :             /*
     967              :              * The user can't set this option in the HBA, so GetOAuthHBAOption
     968              :              * would always return NULL.
     969              :              */
     970            0 :             ereport(WARNING,
     971              :                     errmsg("HBA option name \"%s\" is invalid and will be ignored",
     972              :                            opts[i]),
     973              :             /* translator: the second %s is a function name */
     974              :                     errcontext("validator module \"%s\", in call to %s",
     975              :                                MyProcPort->hba->oauth_validator,
     976              :                                "RegisterOAuthHBAOptions"));
     977            0 :             continue;
     978              :         }
     979              : 
     980            0 :         ValidatorOptions = lappend(ValidatorOptions, pstrdup(opts[i]));
     981              :     }
     982              : 
     983            0 :     MemoryContextSwitchTo(oldcontext);
     984              : 
     985              :     /*
     986              :      * Wait to validate the HBA against the registered options until later
     987              :      * (see check_validator_hba_options()).
     988              :      *
     989              :      * Delaying allows the validator to make multiple registration calls, to
     990              :      * append to the list; it lets us make the check in a place where we can
     991              :      * report the error without leaking details to the client; and it avoids
     992              :      * exporting the order of operations between HBA matching and the
     993              :      * startup_cb call as an API guarantee. (The last issue may become
     994              :      * relevant with a threaded model.)
     995              :      */
     996              : }
     997              : 
     998              : /*
     999              :  * Restrict the names available to custom HBA options, so that we don't
    1000              :  * accidentally prevent future syntax extensions to HBA files.
    1001              :  */
    1002              : bool
    1003            0 : valid_oauth_hba_option_name(const char *name)
    1004              : {
    1005              :     /*
    1006              :      * This list is not incredibly principled, since the goal is just to bound
    1007              :      * compatibility guarantees for our HBA parser. Alphanumerics seem
    1008              :      * obviously fine, and it's difficult to argue against the punctuation
    1009              :      * that's already included in some HBA option names and identifiers.
    1010              :      */
    1011              :     static const char *name_allowed_set =
    1012              :         "abcdefghijklmnopqrstuvwxyz"
    1013              :         "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    1014              :         "0123456789_-";
    1015              : 
    1016              :     size_t      span;
    1017              : 
    1018            0 :     if (!name[0])
    1019            0 :         return false;
    1020              : 
    1021            0 :     span = strspn(name, name_allowed_set);
    1022            0 :     return name[span] == '\0';
    1023              : }
    1024              : 
    1025              : /*
    1026              :  * Verifies that all validator.* HBA options specified by the user were actually
    1027              :  * registered by the validator library in use.
    1028              :  */
    1029              : static bool
    1030            0 : check_validator_hba_options(Port *port, const char **logdetail)
    1031              : {
    1032            0 :     HbaLine    *hba = port->hba;
    1033              : 
    1034            0 :     foreach_ptr(char, key, hba->oauth_opt_keys)
    1035              :     {
    1036            0 :         bool        found = false;
    1037              : 
    1038              :         /* O(n^2) shouldn't be a problem here in practice. */
    1039            0 :         foreach_ptr(char, optname, ValidatorOptions)
    1040              :         {
    1041            0 :             if (strcmp(key, optname) == 0)
    1042              :             {
    1043            0 :                 found = true;
    1044            0 :                 break;
    1045              :             }
    1046              :         }
    1047              : 
    1048            0 :         if (!found)
    1049              :         {
    1050              :             /*
    1051              :              * Unknown option name. Mirror the error messages in hba.c here,
    1052              :              * keeping in mind that the original "validator." prefix was
    1053              :              * stripped from the key during parsing.
    1054              :              *
    1055              :              * Since this is affecting live connections, which is unusual for
    1056              :              * HBA, be noisy with a WARNING. (Warnings aren't sent to clients
    1057              :              * prior to successful authentication, so this won't disclose the
    1058              :              * server config.) It'll duplicate some of the information in the
    1059              :              * logdetail, but that should make it hard to miss the connection
    1060              :              * between the two.
    1061              :              */
    1062            0 :             char       *name = psprintf("validator.%s", key);
    1063              : 
    1064            0 :             *logdetail = psprintf(_("unrecognized authentication option name: \"%s\""),
    1065              :                                   name);
    1066            0 :             ereport(WARNING,
    1067              :                     errcode(ERRCODE_CONFIG_FILE_ERROR),
    1068              :                     errmsg("unrecognized authentication option name: \"%s\"",
    1069              :                            name),
    1070              :             /* translator: the first %s is the name of the module */
    1071              :                     errdetail("The installed validator module (\"%s\") did not define an option named \"%s\".",
    1072              :                               hba->oauth_validator, key),
    1073              :                     errhint("All OAuth connections matching this line will fail. Correct the option and reload the server configuration."),
    1074              :                     errcontext("line %d of configuration file \"%s\"",
    1075              :                                hba->linenumber, hba->sourcefile));
    1076              : 
    1077            0 :             return false;
    1078              :         }
    1079              :     }
    1080              : 
    1081            0 :     ValidatorOptionsChecked = true; /* unfetter GetOAuthHBAOption() */
    1082            0 :     return true;
    1083              : }
    1084              : 
    1085              : /*
    1086              :  * Retrieves the setting for a validator.* HBA option, or NULL if not found.
    1087              :  * This may only be used during the validate_cb and shutdown_cb.
    1088              :  */
    1089              : const char *
    1090            0 : GetOAuthHBAOption(const ValidatorModuleState *state, const char *optname)
    1091              : {
    1092            0 :     HbaLine    *hba = MyProcPort->hba;
    1093              :     ListCell   *lc_k;
    1094              :     ListCell   *lc_v;
    1095            0 :     const char *ret = NULL;
    1096              : 
    1097            0 :     if (!ValidatorOptionsChecked)
    1098              :     {
    1099              :         /*
    1100              :          * Prevent the startup_cb from retrieving HBA options that it has just
    1101              :          * registered. This probably seems strange -- why refuse to hand out
    1102              :          * information we already know? -- but this lets us reserve the
    1103              :          * ability to perform the startup_cb call earlier, before we know
    1104              :          * which HBA line is matched by a connection, without breaking this
    1105              :          * API.
    1106              :          */
    1107            0 :         return NULL;
    1108              :     }
    1109              : 
    1110            0 :     if (!state || !hba)
    1111              :     {
    1112              :         Assert(false);
    1113            0 :         return NULL;
    1114              :     }
    1115              : 
    1116              :     Assert(list_length(hba->oauth_opt_keys) == list_length(hba->oauth_opt_vals));
    1117              : 
    1118            0 :     forboth(lc_k, hba->oauth_opt_keys, lc_v, hba->oauth_opt_vals)
    1119              :     {
    1120            0 :         const char *key = lfirst(lc_k);
    1121            0 :         const char *val = lfirst(lc_v);
    1122              : 
    1123            0 :         if (strcmp(key, optname) == 0)
    1124              :         {
    1125              :             /*
    1126              :              * Don't return yet -- when regular HBA options are specified more
    1127              :              * than once, the last one wins. Do the same for these options.
    1128              :              */
    1129            0 :             ret = val;
    1130              :         }
    1131              :     }
    1132              : 
    1133            0 :     return ret;
    1134              : }
        

Generated by: LCOV version 2.0-1