LCOV - code coverage report
Current view: top level - src/backend/libpq - auth-oauth.c (source / functions) Coverage Total Hit
Test: PostgreSQL 19devel Lines: 0.0 % 242 0
Test Date: 2026-04-07 14:16:30 Functions: 0.0 % 12 0
Legend: Lines:     hit not hit

            Line data    Source code
       1              : /*-------------------------------------------------------------------------
       2              :  *
       3              :  * auth-oauth.c
       4              :  *    Server-side implementation of the SASL OAUTHBEARER mechanism.
       5              :  *
       6              :  * See the following RFC for more details:
       7              :  * - RFC 7628: https://datatracker.ietf.org/doc/html/rfc7628
       8              :  *
       9              :  * Portions Copyright (c) 1996-2026, PostgreSQL Global Development Group
      10              :  * Portions Copyright (c) 1994, Regents of the University of California
      11              :  *
      12              :  * src/backend/libpq/auth-oauth.c
      13              :  *
      14              :  *-------------------------------------------------------------------------
      15              :  */
      16              : #include "postgres.h"
      17              : 
      18              : #include <unistd.h>
      19              : #include <fcntl.h>
      20              : 
      21              : #include "common/oauth-common.h"
      22              : #include "fmgr.h"
      23              : #include "lib/stringinfo.h"
      24              : #include "libpq/auth.h"
      25              : #include "libpq/hba.h"
      26              : #include "libpq/oauth.h"
      27              : #include "libpq/sasl.h"
      28              : #include "storage/fd.h"
      29              : #include "storage/ipc.h"
      30              : #include "utils/json.h"
      31              : #include "utils/varlena.h"
      32              : 
/*
 * GUC backing string.  NOTE(review): presumably holds the configured list of
 * permitted validator libraries; it is not read anywhere in this chunk, so
 * confirm its parsing/usage elsewhere in the file.
 */
char       *oauth_validator_libraries_string = NULL;

/* SASL mechanism callbacks, wired into pg_be_oauth_mech below. */
static void oauth_get_mechanisms(Port *port, StringInfo buf);
static void *oauth_init(Port *port, const char *selected_mech, const char *shadow_pass);
static int  oauth_exchange(void *opaq, const char *input, int inputlen,
                           char **output, int *outputlen, const char **logdetail);

/*
 * Validator module management.  load_validator_library() is called from
 * oauth_init() with the HBA line's validator name; the bodies of both
 * functions are outside this chunk.
 */
static void load_validator_library(const char *libname);
static void shutdown_validator_library(void *arg);

/*
 * State of the currently loaded validator module and its callback table.
 * NOTE(review): presumably populated by load_validator_library(); not
 * visible in this chunk.
 */
static ValidatorModuleState *validator_module_state;
static const OAuthValidatorCallbacks *ValidatorCallbacks;

/*
 * Mechanism declaration: the callback table for the server-side SASL
 * OAUTHBEARER implementation.  Client messages larger than
 * PG_MAX_AUTH_TOKEN_LENGTH are limited via max_message_length.
 */
const pg_be_sasl_mech pg_be_oauth_mech = {
    .get_mechanisms = oauth_get_mechanisms,
    .init = oauth_init,
    .exchange = oauth_exchange,

    .max_message_length = PG_MAX_AUTH_TOKEN_LENGTH,
};
      55              : 
/*
 * Valid states for the oauth_exchange() machine.
 *
 *  INIT             - waiting for the client's initial response.
 *  ERROR            - a JSON error document was sent because validate()
 *                     rejected the token; the client must reply with a
 *                     single kvsep.
 *  ERROR_DISCOVERY  - as ERROR, but triggered by an empty auth value (an
 *                     issuer discovery request).
 *  FINISHED         - the exchange is over; any further call is an
 *                     "invalid OAUTHBEARER exchange state" error.
 */
enum oauth_state
{
    OAUTH_STATE_INIT = 0,
    OAUTH_STATE_ERROR,
    OAUTH_STATE_ERROR_DISCOVERY,
    OAUTH_STATE_FINISHED,
};

/*
 * Mechanism callback state, allocated by oauth_init() and passed back to
 * oauth_exchange() as its opaque pointer.
 */
struct oauth_ctx
{
    enum oauth_state state;     /* current step of the exchange */
    Port       *port;           /* connection being authenticated */
    const char *issuer;         /* copied from port->hba->oauth_issuer */
    const char *scope;          /* copied from port->hba->oauth_scope */
};

/* Helpers for parsing the client response and building the reply. */
static char *sanitize_char(char c);
static char *parse_kvpairs_for_auth(char **input);
static void generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen);
static bool validate(Port *port, const char *auth, const char **logdetail);

/* Constants seen in an OAUTHBEARER client initial response. */
#define KVSEP 0x01              /* separator byte for key/value pairs */
#define AUTH_KEY "auth"           /* key containing the Authorization header */
#define BEARER_SCHEME "Bearer " /* required header scheme (case-insensitive!) */
      83              : 
      84              : /*
      85              :  * Retrieves the OAUTHBEARER mechanism list (currently a single item).
      86              :  *
      87              :  * For a full description of the API, see libpq/sasl.h.
      88              :  */
      89              : static void
      90            0 : oauth_get_mechanisms(Port *port, StringInfo buf)
      91              : {
      92              :     /* Only OAUTHBEARER is supported. */
      93            0 :     appendStringInfoString(buf, OAUTHBEARER_NAME);
      94            0 :     appendStringInfoChar(buf, '\0');
      95            0 : }
      96              : 
      97              : /*
      98              :  * Initializes mechanism state and loads the configured validator module.
      99              :  *
     100              :  * For a full description of the API, see libpq/sasl.h.
     101              :  */
     102              : static void *
     103            0 : oauth_init(Port *port, const char *selected_mech, const char *shadow_pass)
     104              : {
     105              :     struct oauth_ctx *ctx;
     106              : 
     107            0 :     if (strcmp(selected_mech, OAUTHBEARER_NAME) != 0)
     108            0 :         ereport(ERROR,
     109              :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     110              :                 errmsg("client selected an invalid SASL authentication mechanism"));
     111              : 
     112            0 :     ctx = palloc0_object(struct oauth_ctx);
     113              : 
     114            0 :     ctx->state = OAUTH_STATE_INIT;
     115            0 :     ctx->port = port;
     116              : 
     117              :     Assert(port->hba);
     118            0 :     ctx->issuer = port->hba->oauth_issuer;
     119            0 :     ctx->scope = port->hba->oauth_scope;
     120              : 
     121            0 :     load_validator_library(port->hba->oauth_validator);
     122              : 
     123            0 :     return ctx;
     124              : }
     125              : 
/*
 * Implements the OAUTHBEARER SASL exchange (RFC 7628, Sec. 3.2). This pulls
 * apart the client initial response and validates the Bearer token. It also
 * handles the dummy error response for a failed handshake, as described in
 * Sec. 3.2.3.
 *
 * opaq is the struct oauth_ctx allocated by oauth_init().  Outcomes:
 *
 *  - input == NULL (no initial response): an empty challenge is returned
 *    with PG_SASL_EXCHANGE_CONTINUE.
 *  - Initial response whose auth value is empty (discovery), or whose token
 *    validate() rejects: *output receives the JSON document built by
 *    generate_error_response() and PG_SASL_EXCHANGE_CONTINUE is returned.
 *    The client must then answer with a single kvsep.
 *  - That follow-up kvsep yields PG_SASL_EXCHANGE_ABANDONED after a discovery
 *    request, or PG_SASL_EXCHANGE_FAILURE after a rejected token.
 *  - A token accepted by validate() yields PG_SASL_EXCHANGE_SUCCESS.
 *
 * Malformed client messages are reported with ereport(ERROR) and do not
 * return.
 *
 * For a full description of the API, see libpq/sasl.h.
 */
static int
oauth_exchange(void *opaq, const char *input, int inputlen,
               char **output, int *outputlen, const char **logdetail)
{
    char       *input_copy;
    char       *p;
    char        cbind_flag;
    char       *auth;
    int         status;

    struct oauth_ctx *ctx = opaq;

    /* Default: no server data to send back. */
    *output = NULL;
    *outputlen = -1;

    /*
     * If the client didn't include an "Initial Client Response" in the
     * SASLInitialResponse message, send an empty challenge, to which the
     * client will respond with the same data that usually comes in the
     * Initial Client Response.
     */
    if (input == NULL)
    {
        Assert(ctx->state == OAUTH_STATE_INIT);

        *output = pstrdup("");
        *outputlen = 0;
        return PG_SASL_EXCHANGE_CONTINUE;
    }

    /*
     * Check that the input length agrees with the string length of the input.
     * (An embedded NUL byte would make strlen() shorter than inputlen.)
     */
    if (inputlen == 0)
        ereport(ERROR,
                errcode(ERRCODE_PROTOCOL_VIOLATION),
                errmsg("malformed OAUTHBEARER message"),
                errdetail("The message is empty."));
    if (inputlen != strlen(input))
        ereport(ERROR,
                errcode(ERRCODE_PROTOCOL_VIOLATION),
                errmsg("malformed OAUTHBEARER message"),
                errdetail("Message length does not match input length."));

    switch (ctx->state)
    {
        case OAUTH_STATE_INIT:
            /* Handle this case below. */
            break;

        case OAUTH_STATE_ERROR:
        case OAUTH_STATE_ERROR_DISCOVERY:

            /*
             * Only one response is valid for the client during authentication
             * failure: a single kvsep.
             */
            if (inputlen != 1 || *input != KVSEP)
                ereport(ERROR,
                        errcode(ERRCODE_PROTOCOL_VIOLATION),
                        errmsg("malformed OAUTHBEARER message"),
                        errdetail("Client did not send a kvsep response."));

            /*
             * The (failed) handshake is now complete. Don't report discovery
             * requests in the server log unless the log level is high enough.
             */
            if (ctx->state == OAUTH_STATE_ERROR_DISCOVERY)
            {
                ereport(DEBUG1, errmsg("OAuth issuer discovery requested"));

                ctx->state = OAUTH_STATE_FINISHED;
                return PG_SASL_EXCHANGE_ABANDONED;
            }

            /* We're not in discovery, so this is just a normal auth failure. */
            ctx->state = OAUTH_STATE_FINISHED;
            return PG_SASL_EXCHANGE_FAILURE;

        default:
            /* Includes OAUTH_STATE_FINISHED: no more messages are expected. */
            elog(ERROR, "invalid OAUTHBEARER exchange state");
            /* Not reached; presumably kept to keep compilers quiet. */
            return PG_SASL_EXCHANGE_FAILURE;
    }

    /*
     * Handle the client's initial message.  Work on a private copy because
     * parse_kvpairs_for_auth() overwrites separators with NUL bytes.
     */
    p = input_copy = pstrdup(input);

    /*
     * OAUTHBEARER does not currently define a channel binding (so there is no
     * OAUTHBEARER-PLUS, and we do not accept a 'p' specifier). We accept a
     * 'y' specifier purely for the remote chance that a future specification
     * could define one; then future clients can still interoperate with this
     * server implementation. 'n' is the expected case.
     */
    cbind_flag = *p;
    switch (cbind_flag)
    {
        case 'p':
            ereport(ERROR,
                    errcode(ERRCODE_PROTOCOL_VIOLATION),
                    errmsg("malformed OAUTHBEARER message"),
                    errdetail("The server does not support channel binding for OAuth, but the client message includes channel binding data."));
            break;

        case 'y':               /* fall through */
        case 'n':
            p++;
            if (*p != ',')
                ereport(ERROR,
                        errcode(ERRCODE_PROTOCOL_VIOLATION),
                        errmsg("malformed OAUTHBEARER message"),
                        errdetail("Comma expected, but found character \"%s\".",
                                  sanitize_char(*p)));
            p++;
            break;

        default:
            ereport(ERROR,
                    errcode(ERRCODE_PROTOCOL_VIOLATION),
                    errmsg("malformed OAUTHBEARER message"),
                    errdetail("Unexpected channel-binding flag \"%s\".",
                              sanitize_char(cbind_flag)));
    }

    /*
     * Forbid optional authzid (authorization identity).  We don't support it.
     */
    if (*p == 'a')
        ereport(ERROR,
                errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
                errmsg("client uses authorization identity, but it is not supported"));
    if (*p != ',')
        ereport(ERROR,
                errcode(ERRCODE_PROTOCOL_VIOLATION),
                errmsg("malformed OAUTHBEARER message"),
                errdetail("Unexpected attribute \"%s\" in client-first-message.",
                          sanitize_char(*p)));
    p++;

    /* All remaining fields are separated by the RFC's kvsep (\x01). */
    if (*p != KVSEP)
        ereport(ERROR,
                errcode(ERRCODE_PROTOCOL_VIOLATION),
                errmsg("malformed OAUTHBEARER message"),
                errdetail("Key-value separator expected, but found character \"%s\".",
                          sanitize_char(*p)));
    p++;

    /* auth points into input_copy; p is advanced past the final kvsep. */
    auth = parse_kvpairs_for_auth(&p);
    if (!auth)
        ereport(ERROR,
                errcode(ERRCODE_PROTOCOL_VIOLATION),
                errmsg("malformed OAUTHBEARER message"),
                errdetail("Message does not contain an auth value."));

    /* We should be at the end of our message. */
    if (*p)
        ereport(ERROR,
                errcode(ERRCODE_PROTOCOL_VIOLATION),
                errmsg("malformed OAUTHBEARER message"),
                errdetail("Message contains additional data after the final terminator."));

    if (auth[0] == '\0')
    {
        /*
         * An empty auth value represents a discovery request; the client
         * expects it to fail.  Skip validation entirely and move directly to
         * the error response.
         */
        generate_error_response(ctx, output, outputlen);

        ctx->state = OAUTH_STATE_ERROR_DISCOVERY;
        status = PG_SASL_EXCHANGE_CONTINUE;
    }
    else if (!validate(ctx->port, auth, logdetail))
    {
        /* Rejected token: send the error document, await the client's kvsep. */
        generate_error_response(ctx, output, outputlen);

        ctx->state = OAUTH_STATE_ERROR;
        status = PG_SASL_EXCHANGE_CONTINUE;
    }
    else
    {
        ctx->state = OAUTH_STATE_FINISHED;
        status = PG_SASL_EXCHANGE_SUCCESS;
    }

    /*
     * Don't let extra copies of the bearer token hang around.
     *
     * NOTE(review): this is reached only on the normal return path; the
     * ereport(ERROR) exits above leave input_copy unwiped (presumably until
     * its memory context is released).
     */
    explicit_bzero(input_copy, inputlen);

    return status;
}
     326              : 
     327              : /*
     328              :  * Convert an arbitrary byte to printable form.  For error messages.
     329              :  *
     330              :  * If it's a printable ASCII character, print it as a single character.
     331              :  * otherwise, print it in hex.
     332              :  *
     333              :  * The returned pointer points to a static buffer.
     334              :  */
     335              : static char *
     336            0 : sanitize_char(char c)
     337              : {
     338              :     static char buf[5];
     339              : 
     340            0 :     if (c >= 0x21 && c <= 0x7E)
     341            0 :         snprintf(buf, sizeof(buf), "'%c'", c);
     342              :     else
     343            0 :         snprintf(buf, sizeof(buf), "0x%02x", (unsigned char) c);
     344            0 :     return buf;
     345              : }
     346              : 
     347              : /*
     348              :  * Performs syntactic validation of a key and value from the initial client
     349              :  * response. (Semantic validation of interesting values must be performed
     350              :  * later.)
     351              :  */
     352              : static void
     353            0 : validate_kvpair(const char *key, const char *val)
     354              : {
     355              :     /*-----
     356              :      * From Sec 3.1:
     357              :      *     key            = 1*(ALPHA)
     358              :      */
     359              :     static const char *key_allowed_set =
     360              :         "abcdefghijklmnopqrstuvwxyz"
     361              :         "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
     362              : 
     363              :     size_t      span;
     364              : 
     365            0 :     if (!key[0])
     366            0 :         ereport(ERROR,
     367              :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     368              :                 errmsg("malformed OAUTHBEARER message"),
     369              :                 errdetail("Message contains an empty key name."));
     370              : 
     371            0 :     span = strspn(key, key_allowed_set);
     372            0 :     if (key[span] != '\0')
     373            0 :         ereport(ERROR,
     374              :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     375              :                 errmsg("malformed OAUTHBEARER message"),
     376              :                 errdetail("Message contains an invalid key name."));
     377              : 
     378              :     /*-----
     379              :      * From Sec 3.1:
     380              :      *     value          = *(VCHAR / SP / HTAB / CR / LF )
     381              :      *
     382              :      * The VCHAR (visible character) class is large; a loop is more
     383              :      * straightforward than strspn().
     384              :      */
     385            0 :     for (; *val; ++val)
     386              :     {
     387            0 :         if (0x21 <= *val && *val <= 0x7E)
     388            0 :             continue;           /* VCHAR */
     389              : 
     390            0 :         switch (*val)
     391              :         {
     392            0 :             case ' ':
     393              :             case '\t':
     394              :             case '\r':
     395              :             case '\n':
     396            0 :                 continue;       /* SP, HTAB, CR, LF */
     397              : 
     398            0 :             default:
     399            0 :                 ereport(ERROR,
     400              :                         errcode(ERRCODE_PROTOCOL_VIOLATION),
     401              :                         errmsg("malformed OAUTHBEARER message"),
     402              :                         errdetail("Message contains an invalid value."));
     403              :         }
     404              :     }
     405            0 : }
     406              : 
     407              : /*
     408              :  * Consumes all kvpairs in an OAUTHBEARER exchange message. If the "auth" key is
     409              :  * found, its value is returned.
     410              :  */
     411              : static char *
     412            0 : parse_kvpairs_for_auth(char **input)
     413              : {
     414            0 :     char       *pos = *input;
     415            0 :     char       *auth = NULL;
     416              : 
     417              :     /*----
     418              :      * The relevant ABNF, from Sec. 3.1:
     419              :      *
     420              :      *     kvsep          = %x01
     421              :      *     key            = 1*(ALPHA)
     422              :      *     value          = *(VCHAR / SP / HTAB / CR / LF )
     423              :      *     kvpair         = key "=" value kvsep
     424              :      *   ;;gs2-header     = See RFC 5801
     425              :      *     client-resp    = (gs2-header kvsep *kvpair kvsep) / kvsep
     426              :      *
     427              :      * By the time we reach this code, the gs2-header and initial kvsep have
     428              :      * already been validated. We start at the beginning of the first kvpair.
     429              :      */
     430              : 
     431            0 :     while (*pos)
     432              :     {
     433              :         char       *end;
     434              :         char       *sep;
     435              :         char       *key;
     436              :         char       *value;
     437              : 
     438              :         /*
     439              :          * Find the end of this kvpair. Note that input is null-terminated by
     440              :          * the SASL code, so the strchr() is bounded.
     441              :          */
     442            0 :         end = strchr(pos, KVSEP);
     443            0 :         if (!end)
     444            0 :             ereport(ERROR,
     445              :                     errcode(ERRCODE_PROTOCOL_VIOLATION),
     446              :                     errmsg("malformed OAUTHBEARER message"),
     447              :                     errdetail("Message contains an unterminated key/value pair."));
     448            0 :         *end = '\0';
     449              : 
     450            0 :         if (pos == end)
     451              :         {
     452              :             /* Empty kvpair, signifying the end of the list. */
     453            0 :             *input = pos + 1;
     454            0 :             return auth;
     455              :         }
     456              : 
     457              :         /*
     458              :          * Find the end of the key name.
     459              :          */
     460            0 :         sep = strchr(pos, '=');
     461            0 :         if (!sep)
     462            0 :             ereport(ERROR,
     463              :                     errcode(ERRCODE_PROTOCOL_VIOLATION),
     464              :                     errmsg("malformed OAUTHBEARER message"),
     465              :                     errdetail("Message contains a key without a value."));
     466            0 :         *sep = '\0';
     467              : 
     468              :         /* Both key and value are now safely terminated. */
     469            0 :         key = pos;
     470            0 :         value = sep + 1;
     471            0 :         validate_kvpair(key, value);
     472              : 
     473            0 :         if (strcmp(key, AUTH_KEY) == 0)
     474              :         {
     475            0 :             if (auth)
     476            0 :                 ereport(ERROR,
     477              :                         errcode(ERRCODE_PROTOCOL_VIOLATION),
     478              :                         errmsg("malformed OAUTHBEARER message"),
     479              :                         errdetail("Message contains multiple auth values."));
     480              : 
     481            0 :             auth = value;
     482              :         }
     483              :         else
     484              :         {
     485              :             /*
     486              :              * The RFC also defines the host and port keys, but they are not
     487              :              * required for OAUTHBEARER and we do not use them. Also, per Sec.
     488              :              * 3.1, any key/value pairs we don't recognize must be ignored.
     489              :              */
     490              :         }
     491              : 
     492              :         /* Move to the next pair. */
     493            0 :         pos = end + 1;
     494              :     }
     495              : 
     496            0 :     ereport(ERROR,
     497              :             errcode(ERRCODE_PROTOCOL_VIOLATION),
     498              :             errmsg("malformed OAUTHBEARER message"),
     499              :             errdetail("Message did not contain a final terminator."));
     500              : 
     501              :     pg_unreachable();
     502              :     return NULL;
     503              : }
     504              : 
     505              : /*
     506              :  * Builds the JSON response for failed authentication (RFC 7628, Sec. 3.2.2).
     507              :  * This contains the required scopes for entry and a pointer to the OAuth/OpenID
     508              :  * discovery document, which the client may use to conduct its OAuth flow.
     509              :  */
     510              : static void
     511            0 : generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen)
     512              : {
     513              :     StringInfoData buf;
     514              :     StringInfoData issuer;
     515              : 
     516              :     /*
     517              :      * The admin needs to set an issuer and scope for OAuth to work. There's
     518              :      * not really a way to hide this from the user, either, because we can't
     519              :      * choose a "default" issuer, so be honest in the failure message. (In
     520              :      * practice such configurations are rejected during HBA parsing.)
     521              :      */
     522            0 :     if (!ctx->issuer || !ctx->scope)
     523            0 :         ereport(FATAL,
     524              :                 errcode(ERRCODE_INTERNAL_ERROR),
     525              :                 errmsg("OAuth is not properly configured for this user"),
     526              :                 errdetail_log("The issuer and scope parameters must be set in pg_hba.conf."));
     527              : 
     528              :     /*
     529              :      * Build a default .well-known URI based on our issuer, unless the HBA has
     530              :      * already provided one.
     531              :      */
     532            0 :     initStringInfo(&issuer);
     533            0 :     appendStringInfoString(&issuer, ctx->issuer);
     534            0 :     if (strstr(ctx->issuer, "/.well-known/") == NULL)
     535            0 :         appendStringInfoString(&issuer, "/.well-known/openid-configuration");
     536              : 
     537            0 :     initStringInfo(&buf);
     538              : 
     539              :     /*
     540              :      * Escaping the string here is belt-and-suspenders defensive programming
     541              :      * since escapable characters aren't valid in either the issuer URI or the
     542              :      * scope list, but the HBA doesn't enforce that yet.
     543              :      */
     544            0 :     appendStringInfoString(&buf, "{ \"status\": \"invalid_token\", ");
     545              : 
     546            0 :     appendStringInfoString(&buf, "\"openid-configuration\": ");
     547            0 :     escape_json(&buf, issuer.data);
     548            0 :     pfree(issuer.data);
     549              : 
     550            0 :     appendStringInfoString(&buf, ", \"scope\": ");
     551            0 :     escape_json(&buf, ctx->scope);
     552              : 
     553            0 :     appendStringInfoString(&buf, " }");
     554              : 
     555            0 :     *output = buf.data;
     556            0 :     *outputlen = buf.len;
     557            0 : }
     558              : 
     559              : /*-----
     560              :  * Validates the provided Authorization header and returns the token from
     561              :  * within it. NULL is returned on validation failure.
     562              :  *
     563              :  * Only Bearer tokens are accepted. The ABNF is defined in RFC 6750, Sec.
     564              :  * 2.1:
     565              :  *
     566              :  *      b64token    = 1*( ALPHA / DIGIT /
     567              :  *                        "-" / "." / "_" / "~" / "+" / "/" ) *"="
     568              :  *      credentials = "Bearer" 1*SP b64token
     569              :  *
     570              :  * The "credentials" construction is what we receive in our auth value.
     571              :  *
     572              :  * Since that spec is subordinate to HTTP (i.e. the HTTP Authorization
     573              :  * header format; RFC 9110 Sec. 11), the "Bearer" scheme string must be
     574              :  * compared case-insensitively. (This is not mentioned in RFC 6750, but the
     575              :  * OAUTHBEARER spec points it out: RFC 7628 Sec. 4.)
     576              :  *
     577              :  * Invalid formats are technically a protocol violation, but we shouldn't
     578              :  * reflect any information about the sensitive Bearer token back to the
     579              :  * client; log at COMMERROR instead.
     580              :  */
     581              : static const char *
     582            0 : validate_token_format(const char *header)
     583              : {
     584              :     size_t      span;
     585              :     const char *token;
     586              :     static const char *const b64token_allowed_set =
     587              :         "abcdefghijklmnopqrstuvwxyz"
     588              :         "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
     589              :         "0123456789-._~+/";
     590              : 
     591              :     /* Missing auth headers should be handled by the caller. */
     592              :     Assert(header);
     593              :     /* Empty auth (discovery) should be handled before calling validate(). */
     594              :     Assert(header[0] != '\0');
     595              : 
     596            0 :     if (pg_strncasecmp(header, BEARER_SCHEME, strlen(BEARER_SCHEME)))
     597              :     {
     598            0 :         ereport(COMMERROR,
     599              :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     600              :                 errmsg("malformed OAuth bearer token"),
     601              :                 errdetail_log("Client response indicated a non-Bearer authentication scheme."));
     602            0 :         return NULL;
     603              :     }
     604              : 
     605              :     /* Pull the bearer token out of the auth value. */
     606            0 :     token = header + strlen(BEARER_SCHEME);
     607              : 
     608              :     /* Swallow any additional spaces. */
     609            0 :     while (*token == ' ')
     610            0 :         token++;
     611              : 
     612              :     /* Tokens must not be empty. */
     613            0 :     if (!*token)
     614              :     {
     615            0 :         ereport(COMMERROR,
     616              :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     617              :                 errmsg("malformed OAuth bearer token"),
     618              :                 errdetail_log("Bearer token is empty."));
     619            0 :         return NULL;
     620              :     }
     621              : 
     622              :     /*
     623              :      * Make sure the token contains only allowed characters. Tokens may end
     624              :      * with any number of '=' characters.
     625              :      */
     626            0 :     span = strspn(token, b64token_allowed_set);
     627            0 :     while (token[span] == '=')
     628            0 :         span++;
     629              : 
     630            0 :     if (token[span] != '\0')
     631              :     {
     632              :         /*
     633              :          * This error message could be more helpful by printing the
     634              :          * problematic character(s), but that'd be a bit like printing a piece
     635              :          * of someone's password into the logs.
     636              :          */
     637            0 :         ereport(COMMERROR,
     638              :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     639              :                 errmsg("malformed OAuth bearer token"),
     640              :                 errdetail_log("Bearer token is not in the correct format."));
     641            0 :         return NULL;
     642              :     }
     643              : 
     644            0 :     return token;
     645              : }
     646              : 
     647              : /*
     648              :  * Checks that the "auth" kvpair in the client response contains a syntactically
     649              :  * valid Bearer token, then passes it along to the loaded validator module for
     650              :  * authorization. Returns true if validation succeeds.
     651              :  */
     652              : static bool
     653            0 : validate(Port *port, const char *auth, const char **logdetail)
     654              : {
     655              :     int         map_status;
     656              :     ValidatorModuleResult *ret;
     657              :     const char *token;
     658              :     bool        status;
     659              : 
     660              :     /* Ensure that we have a correct token to validate */
     661            0 :     if (!(token = validate_token_format(auth)))
     662            0 :         return false;
     663              : 
     664              :     /*
     665              :      * Ensure that we have a validation library loaded, this should always be
     666              :      * the case and an error here is indicative of a bug.
     667              :      */
     668            0 :     if (!ValidatorCallbacks || !ValidatorCallbacks->validate_cb)
     669            0 :         ereport(FATAL,
     670              :                 errcode(ERRCODE_INTERNAL_ERROR),
     671              :                 errmsg("validation of OAuth token requested without a validator loaded"));
     672              : 
     673              :     /* Call the validation function from the validator module */
     674            0 :     ret = palloc0_object(ValidatorModuleResult);
     675            0 :     if (!ValidatorCallbacks->validate_cb(validator_module_state, token,
     676            0 :                                          port->user_name, ret))
     677              :     {
     678            0 :         ereport(WARNING,
     679              :                 errcode(ERRCODE_INTERNAL_ERROR),
     680              :                 errmsg("internal error in OAuth validator module"),
     681              :                 ret->error_detail ? errdetail_log("%s", ret->error_detail) : 0);
     682              : 
     683            0 :         *logdetail = ret->error_detail;
     684            0 :         return false;
     685              :     }
     686              : 
     687              :     /*
     688              :      * Log any authentication results even if the token isn't authorized; it
     689              :      * might be useful for auditing or troubleshooting.
     690              :      */
     691            0 :     if (ret->authn_id)
     692            0 :         set_authn_id(port, ret->authn_id);
     693              : 
     694            0 :     if (!ret->authorized)
     695              :     {
     696            0 :         if (ret->error_detail)
     697            0 :             *logdetail = ret->error_detail;
     698              :         else
     699            0 :             *logdetail = _("Validator failed to authorize the provided token.");
     700              : 
     701            0 :         status = false;
     702            0 :         goto cleanup;
     703              :     }
     704              : 
     705            0 :     if (port->hba->oauth_skip_usermap)
     706              :     {
     707              :         /*
     708              :          * If the validator is our authorization authority, we're done.
     709              :          * Authentication may or may not have been performed depending on the
     710              :          * validator implementation; all that matters is that the validator
     711              :          * says the user can log in with the target role.
     712              :          */
     713            0 :         status = true;
     714            0 :         goto cleanup;
     715              :     }
     716              : 
     717              :     /* Make sure the validator authenticated the user. */
     718            0 :     if (ret->authn_id == NULL || ret->authn_id[0] == '\0')
     719              :     {
     720            0 :         *logdetail = _("Validator provided no identity.");
     721              : 
     722            0 :         status = false;
     723            0 :         goto cleanup;
     724              :     }
     725              : 
     726              :     /* Finally, check the user map. */
     727            0 :     map_status = check_usermap(port->hba->usermap, port->user_name,
     728              :                                MyClientConnectionInfo.authn_id, false);
     729            0 :     status = (map_status == STATUS_OK);
     730              : 
     731            0 : cleanup:
     732              : 
     733              :     /*
     734              :      * Clear and free the validation result from the validator module once
     735              :      * we're done with it.
     736              :      */
     737            0 :     if (ret->authn_id != NULL)
     738            0 :         pfree(ret->authn_id);
     739            0 :     pfree(ret);
     740              : 
     741            0 :     return status;
     742              : }
     743              : 
     744              : /*
     745              :  * load_validator_library
     746              :  *
     747              :  * Load the configured validator library in order to perform token validation.
     748              :  * There is no built-in fallback since validation is implementation specific. If
     749              :  * no validator library is configured, or if it fails to load, then error out
     750              :  * since token validation won't be possible.
     751              :  */
     752              : static void
     753            0 : load_validator_library(const char *libname)
     754              : {
     755              :     OAuthValidatorModuleInit validator_init;
     756              :     MemoryContextCallback *mcb;
     757              : 
     758              :     /*
     759              :      * The presence, and validity, of libname has already been established by
     760              :      * check_oauth_validator so we don't need to perform more than Assert
     761              :      * level checking here.
     762              :      */
     763              :     Assert(libname && *libname);
     764              : 
     765            0 :     validator_init = (OAuthValidatorModuleInit)
     766            0 :         load_external_function(libname, "_PG_oauth_validator_module_init",
     767              :                                false, NULL);
     768              : 
     769              :     /*
     770              :      * The validator init function is required since it will set the callbacks
     771              :      * for the validator library.
     772              :      */
     773            0 :     if (validator_init == NULL)
     774            0 :         ereport(ERROR,
     775              :                 errmsg("%s module \"%s\" must define the symbol %s",
     776              :                        "OAuth validator", libname, "_PG_oauth_validator_module_init"));
     777              : 
     778            0 :     ValidatorCallbacks = (*validator_init) ();
     779              :     Assert(ValidatorCallbacks);
     780              : 
     781              :     /*
     782              :      * Check the magic number, to protect against break-glass scenarios where
     783              :      * the ABI must change within a major version. load_external_function()
     784              :      * already checks for compatibility across major versions.
     785              :      */
     786            0 :     if (ValidatorCallbacks->magic != PG_OAUTH_VALIDATOR_MAGIC)
     787            0 :         ereport(ERROR,
     788              :                 errmsg("%s module \"%s\": magic number mismatch",
     789              :                        "OAuth validator", libname),
     790              :                 errdetail("Server has magic number 0x%08X, module has 0x%08X.",
     791              :                           PG_OAUTH_VALIDATOR_MAGIC, ValidatorCallbacks->magic));
     792              : 
     793              :     /*
     794              :      * Make sure all required callbacks are present in the ValidatorCallbacks
     795              :      * structure. Right now only the validation callback is required.
     796              :      */
     797            0 :     if (ValidatorCallbacks->validate_cb == NULL)
     798            0 :         ereport(ERROR,
     799              :                 errmsg("%s module \"%s\" must provide a %s callback",
     800              :                        "OAuth validator", libname, "validate_cb"));
     801              : 
     802              :     /* Allocate memory for validator library private state data */
     803            0 :     validator_module_state = palloc0_object(ValidatorModuleState);
     804            0 :     validator_module_state->sversion = PG_VERSION_NUM;
     805              : 
     806            0 :     if (ValidatorCallbacks->startup_cb != NULL)
     807            0 :         ValidatorCallbacks->startup_cb(validator_module_state);
     808              : 
     809              :     /* Shut down the library before cleaning up its state. */
     810            0 :     mcb = palloc0_object(MemoryContextCallback);
     811            0 :     mcb->func = shutdown_validator_library;
     812              : 
     813            0 :     MemoryContextRegisterResetCallback(CurrentMemoryContext, mcb);
     814            0 : }
     815              : 
     816              : /*
     817              :  * Call the validator module's shutdown callback, if one is provided. This is
     818              :  * invoked during memory context reset.
     819              :  */
     820              : static void
     821            0 : shutdown_validator_library(void *arg)
     822              : {
     823            0 :     if (ValidatorCallbacks->shutdown_cb != NULL)
     824            0 :         ValidatorCallbacks->shutdown_cb(validator_module_state);
     825            0 : }
     826              : 
     827              : /*
     828              :  * Ensure an OAuth validator named in the HBA is permitted by the configuration.
     829              :  *
     830              :  * If the validator is currently unset and exactly one library is declared in
     831              :  * oauth_validator_libraries, then that library will be used as the validator.
     832              :  * Otherwise the name must be present in the list of oauth_validator_libraries.
     833              :  */
     834              : bool
     835            0 : check_oauth_validator(HbaLine *hbaline, int elevel, char **err_msg)
     836              : {
     837            0 :     int         line_num = hbaline->linenumber;
     838            0 :     const char *file_name = hbaline->sourcefile;
     839              :     char       *rawstring;
     840            0 :     List       *elemlist = NIL;
     841              : 
     842            0 :     *err_msg = NULL;
     843              : 
     844            0 :     if (oauth_validator_libraries_string[0] == '\0')
     845              :     {
     846            0 :         ereport(elevel,
     847              :                 errcode(ERRCODE_CONFIG_FILE_ERROR),
     848              :                 errmsg("oauth_validator_libraries must be set for authentication method %s",
     849              :                        "oauth"),
     850              :                 errcontext("line %d of configuration file \"%s\"",
     851              :                            line_num, file_name));
     852            0 :         *err_msg = psprintf("oauth_validator_libraries must be set for authentication method %s",
     853              :                             "oauth");
     854            0 :         return false;
     855              :     }
     856              : 
     857              :     /* SplitDirectoriesString needs a modifiable copy */
     858            0 :     rawstring = pstrdup(oauth_validator_libraries_string);
     859              : 
     860            0 :     if (!SplitDirectoriesString(rawstring, ',', &elemlist))
     861              :     {
     862              :         /* syntax error in list */
     863            0 :         ereport(elevel,
     864              :                 errcode(ERRCODE_CONFIG_FILE_ERROR),
     865              :                 errmsg("invalid list syntax in parameter \"%s\"",
     866              :                        "oauth_validator_libraries"));
     867            0 :         *err_msg = psprintf("invalid list syntax in parameter \"%s\"",
     868              :                             "oauth_validator_libraries");
     869            0 :         goto done;
     870              :     }
     871              : 
     872            0 :     if (!hbaline->oauth_validator)
     873              :     {
     874            0 :         if (elemlist->length == 1)
     875              :         {
     876            0 :             hbaline->oauth_validator = pstrdup(linitial(elemlist));
     877            0 :             goto done;
     878              :         }
     879              : 
     880            0 :         ereport(elevel,
     881              :                 errcode(ERRCODE_CONFIG_FILE_ERROR),
     882              :                 errmsg("authentication method \"oauth\" requires argument \"validator\" to be set when oauth_validator_libraries contains multiple options"),
     883              :                 errcontext("line %d of configuration file \"%s\"",
     884              :                            line_num, file_name));
     885            0 :         *err_msg = "authentication method \"oauth\" requires argument \"validator\" to be set when oauth_validator_libraries contains multiple options";
     886            0 :         goto done;
     887              :     }
     888              : 
     889            0 :     foreach_ptr(char, allowed, elemlist)
     890              :     {
     891            0 :         if (strcmp(allowed, hbaline->oauth_validator) == 0)
     892            0 :             goto done;
     893              :     }
     894              : 
     895            0 :     ereport(elevel,
     896              :             errcode(ERRCODE_INVALID_PARAMETER_VALUE),
     897              :             errmsg("validator \"%s\" is not permitted by %s",
     898              :                    hbaline->oauth_validator, "oauth_validator_libraries"),
     899              :             errcontext("line %d of configuration file \"%s\"",
     900              :                        line_num, file_name));
     901            0 :     *err_msg = psprintf("validator \"%s\" is not permitted by %s",
     902              :                         hbaline->oauth_validator, "oauth_validator_libraries");
     903              : 
     904            0 : done:
     905            0 :     list_free_deep(elemlist);
     906            0 :     pfree(rawstring);
     907              : 
     908            0 :     return (*err_msg == NULL);
     909              : }
        

Generated by: LCOV version 2.0-1