LCOV - code coverage report
Current view: top level - src/backend/libpq - auth-oauth.c (source / functions) Coverage Total Hit
Test: PostgreSQL 19devel Lines: 0.0 % 233 0
Test Date: 2026-03-01 11:15:05 Functions: 0.0 % 12 0
Legend: Lines:     hit not hit

            Line data    Source code
       1              : /*-------------------------------------------------------------------------
       2              :  *
       3              :  * auth-oauth.c
       4              :  *    Server-side implementation of the SASL OAUTHBEARER mechanism.
       5              :  *
       6              :  * See the following RFC for more details:
       7              :  * - RFC 7628: https://datatracker.ietf.org/doc/html/rfc7628
       8              :  *
       9              :  * Portions Copyright (c) 1996-2026, PostgreSQL Global Development Group
      10              :  * Portions Copyright (c) 1994, Regents of the University of California
      11              :  *
      12              :  * src/backend/libpq/auth-oauth.c
      13              :  *
      14              :  *-------------------------------------------------------------------------
      15              :  */
      16              : #include "postgres.h"
      17              : 
      18              : #include <unistd.h>
      19              : #include <fcntl.h>
      20              : 
      21              : #include "common/oauth-common.h"
      22              : #include "fmgr.h"
      23              : #include "lib/stringinfo.h"
      24              : #include "libpq/auth.h"
      25              : #include "libpq/hba.h"
      26              : #include "libpq/oauth.h"
      27              : #include "libpq/sasl.h"
      28              : #include "storage/fd.h"
      29              : #include "storage/ipc.h"
      30              : #include "utils/json.h"
      31              : #include "utils/varlena.h"
      32              : 
/*
 * GUC: raw string value of the validator-library setting.  The GUC itself is
 * presumably registered outside this file -- confirm in the GUC tables.
 */
char       *oauth_validator_libraries_string = NULL;

/* SASL mechanism callbacks, wired into pg_be_oauth_mech below. */
static void oauth_get_mechanisms(Port *port, StringInfo buf);
static void *oauth_init(Port *port, const char *selected_mech, const char *shadow_pass);
static int  oauth_exchange(void *opaq, const char *input, int inputlen,
                           char **output, int *outputlen, const char **logdetail);

/* Validator module management (oauth_init() calls load_validator_library()). */
static void load_validator_library(const char *libname);
static void shutdown_validator_library(void *arg);

/*
 * Per-process state for the loaded validator module.  These are presumably
 * filled in by load_validator_library() -- its body is not shown here.
 */
static ValidatorModuleState *validator_module_state;
static const OAuthValidatorCallbacks *ValidatorCallbacks;

/*
 * Mechanism declaration handed to the generic SASL code (see libpq/sasl.h).
 * max_message_length bounds client messages at PG_MAX_AUTH_TOKEN_LENGTH;
 * the enforcement presumably happens in the generic SASL layer.
 */
const pg_be_sasl_mech pg_be_oauth_mech = {
    .get_mechanisms = oauth_get_mechanisms,
    .init = oauth_init,
    .exchange = oauth_exchange,

    .max_message_length = PG_MAX_AUTH_TOKEN_LENGTH,
};
      55              : 
/* Valid states for the oauth_exchange() machine. */
enum oauth_state
{
    /* Waiting for the client initial response; oauth_exchange() parses it. */
    OAUTH_STATE_INIT = 0,
    /* Error challenge sent; the only acceptable reply is a single kvsep. */
    OAUTH_STATE_ERROR,
    /* Exchange over; any further input is rejected as an invalid state. */
    OAUTH_STATE_FINISHED,
};

/*
 * Mechanism callback state, allocated by oauth_init() and passed back to
 * oauth_exchange() as its opaque pointer.
 */
struct oauth_ctx
{
    enum oauth_state state;
    Port       *port;
    const char *issuer;         /* from the HBA line; used in error JSON */
    const char *scope;          /* from the HBA line; used in error JSON */
};

static char *sanitize_char(char c);
static char *parse_kvpairs_for_auth(char **input);
static void generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen);
static bool validate(Port *port, const char *auth);

/* Constants seen in an OAUTHBEARER client initial response. */
#define KVSEP 0x01              /* separator byte for key/value pairs */
#define AUTH_KEY "auth"           /* key containing the Authorization header */
#define BEARER_SCHEME "Bearer " /* required header scheme (case-insensitive!) */
      82              : 
      83              : /*
      84              :  * Retrieves the OAUTHBEARER mechanism list (currently a single item).
      85              :  *
      86              :  * For a full description of the API, see libpq/sasl.h.
      87              :  */
      88              : static void
      89            0 : oauth_get_mechanisms(Port *port, StringInfo buf)
      90              : {
      91              :     /* Only OAUTHBEARER is supported. */
      92            0 :     appendStringInfoString(buf, OAUTHBEARER_NAME);
      93            0 :     appendStringInfoChar(buf, '\0');
      94            0 : }
      95              : 
      96              : /*
      97              :  * Initializes mechanism state and loads the configured validator module.
      98              :  *
      99              :  * For a full description of the API, see libpq/sasl.h.
     100              :  */
     101              : static void *
     102            0 : oauth_init(Port *port, const char *selected_mech, const char *shadow_pass)
     103              : {
     104              :     struct oauth_ctx *ctx;
     105              : 
     106            0 :     if (strcmp(selected_mech, OAUTHBEARER_NAME) != 0)
     107            0 :         ereport(ERROR,
     108              :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     109              :                 errmsg("client selected an invalid SASL authentication mechanism"));
     110              : 
     111            0 :     ctx = palloc0_object(struct oauth_ctx);
     112              : 
     113            0 :     ctx->state = OAUTH_STATE_INIT;
     114            0 :     ctx->port = port;
     115              : 
     116              :     Assert(port->hba);
     117            0 :     ctx->issuer = port->hba->oauth_issuer;
     118            0 :     ctx->scope = port->hba->oauth_scope;
     119              : 
     120            0 :     load_validator_library(port->hba->oauth_validator);
     121              : 
     122            0 :     return ctx;
     123              : }
     124              : 
/*
 * Implements the OAUTHBEARER SASL exchange (RFC 7628, Sec. 3.2). This pulls
 * apart the client initial response and validates the Bearer token. It also
 * handles the dummy error response for a failed handshake, as described in
 * Sec. 3.2.3.
 *
 * Parameters:
 *   opaq      - the struct oauth_ctx returned by oauth_init()
 *   input     - client message, or NULL if the SASLInitialResponse carried
 *               no initial response
 *   inputlen  - length of input; must be nonzero and equal strlen(input)
 *   output, outputlen - set to the challenge to send, or NULL / -1 if none
 *   logdetail - not written by this function
 *
 * Returns PG_SASL_EXCHANGE_CONTINUE when another client message is expected
 * (after the empty challenge, or after the error response), _SUCCESS when
 * validate() accepts the token, and _FAILURE once the client acknowledges the
 * error response.  Malformed messages raise ereport(ERROR).
 *
 * For a full description of the API, see libpq/sasl.h.
 */
static int
oauth_exchange(void *opaq, const char *input, int inputlen,
               char **output, int *outputlen, const char **logdetail)
{
    char       *input_copy;
    char       *p;
    char        cbind_flag;
    char       *auth;
    int         status;

    struct oauth_ctx *ctx = opaq;

    /* By default there is no challenge to send. */
    *output = NULL;
    *outputlen = -1;

    /*
     * If the client didn't include an "Initial Client Response" in the
     * SASLInitialResponse message, send an empty challenge, to which the
     * client will respond with the same data that usually comes in the
     * Initial Client Response.
     */
    if (input == NULL)
    {
        Assert(ctx->state == OAUTH_STATE_INIT);

        *output = pstrdup("");
        *outputlen = 0;
        return PG_SASL_EXCHANGE_CONTINUE;
    }

    /*
     * Check that the input length agrees with the string length of the input.
     * (An embedded NUL byte would make strlen() come up short.)
     */
    if (inputlen == 0)
        ereport(ERROR,
                errcode(ERRCODE_PROTOCOL_VIOLATION),
                errmsg("malformed OAUTHBEARER message"),
                errdetail("The message is empty."));
    if (inputlen != strlen(input))
        ereport(ERROR,
                errcode(ERRCODE_PROTOCOL_VIOLATION),
                errmsg("malformed OAUTHBEARER message"),
                errdetail("Message length does not match input length."));

    switch (ctx->state)
    {
        case OAUTH_STATE_INIT:
            /* Handle this case below. */
            break;

        case OAUTH_STATE_ERROR:

            /*
             * Only one response is valid for the client during authentication
             * failure: a single kvsep.
             */
            if (inputlen != 1 || *input != KVSEP)
                ereport(ERROR,
                        errcode(ERRCODE_PROTOCOL_VIOLATION),
                        errmsg("malformed OAUTHBEARER message"),
                        errdetail("Client did not send a kvsep response."));

            /* The (failed) handshake is now complete. */
            ctx->state = OAUTH_STATE_FINISHED;
            return PG_SASL_EXCHANGE_FAILURE;

        default:
            /* Includes OAUTH_STATE_FINISHED: no further input is expected. */
            elog(ERROR, "invalid OAUTHBEARER exchange state");
            return PG_SASL_EXCHANGE_FAILURE;
    }

    /*
     * Handle the client's initial message.  Parsing writes NUL bytes into the
     * buffer, so work on a private copy rather than the caller's input.
     */
    p = input_copy = pstrdup(input);

    /*
     * The message starts with a gs2-header (RFC 5801): channel-binding flag,
     * ',', optional authzid, ','.
     *
     * OAUTHBEARER does not currently define a channel binding (so there is no
     * OAUTHBEARER-PLUS, and we do not accept a 'p' specifier). We accept a
     * 'y' specifier purely for the remote chance that a future specification
     * could define one; then future clients can still interoperate with this
     * server implementation. 'n' is the expected case.
     */
    cbind_flag = *p;
    switch (cbind_flag)
    {
        case 'p':
            ereport(ERROR,
                    errcode(ERRCODE_PROTOCOL_VIOLATION),
                    errmsg("malformed OAUTHBEARER message"),
                    errdetail("The server does not support channel binding for OAuth, but the client message includes channel binding data."));
            break;

        case 'y':               /* fall through */
        case 'n':
            p++;
            if (*p != ',')
                ereport(ERROR,
                        errcode(ERRCODE_PROTOCOL_VIOLATION),
                        errmsg("malformed OAUTHBEARER message"),
                        errdetail("Comma expected, but found character \"%s\".",
                                  sanitize_char(*p)));
            p++;
            break;

        default:
            ereport(ERROR,
                    errcode(ERRCODE_PROTOCOL_VIOLATION),
                    errmsg("malformed OAUTHBEARER message"),
                    errdetail("Unexpected channel-binding flag \"%s\".",
                              sanitize_char(cbind_flag)));
    }

    /*
     * Forbid optional authzid (authorization identity).  We don't support it.
     */
    if (*p == 'a')
        ereport(ERROR,
                errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
                errmsg("client uses authorization identity, but it is not supported"));
    if (*p != ',')
        ereport(ERROR,
                errcode(ERRCODE_PROTOCOL_VIOLATION),
                errmsg("malformed OAUTHBEARER message"),
                errdetail("Unexpected attribute \"%s\" in client-first-message.",
                          sanitize_char(*p)));
    p++;

    /* All remaining fields are separated by the RFC's kvsep (\x01). */
    if (*p != KVSEP)
        ereport(ERROR,
                errcode(ERRCODE_PROTOCOL_VIOLATION),
                errmsg("malformed OAUTHBEARER message"),
                errdetail("Key-value separator expected, but found character \"%s\".",
                          sanitize_char(*p)));
    p++;

    /*
     * On return, p sits just past the terminating empty kvpair, and auth (if
     * non-NULL) points into input_copy.
     */
    auth = parse_kvpairs_for_auth(&p);
    if (!auth)
        ereport(ERROR,
                errcode(ERRCODE_PROTOCOL_VIOLATION),
                errmsg("malformed OAUTHBEARER message"),
                errdetail("Message does not contain an auth value."));

    /* We should be at the end of our message. */
    if (*p)
        ereport(ERROR,
                errcode(ERRCODE_PROTOCOL_VIOLATION),
                errmsg("malformed OAUTHBEARER message"),
                errdetail("Message contains additional data after the final terminator."));

    if (!validate(ctx->port, auth))
    {
        /*
         * Send the error/discovery JSON; the client must then answer with a
         * lone kvsep, which the OAUTH_STATE_ERROR case above handles.
         */
        generate_error_response(ctx, output, outputlen);

        ctx->state = OAUTH_STATE_ERROR;
        status = PG_SASL_EXCHANGE_CONTINUE;
    }
    else
    {
        ctx->state = OAUTH_STATE_FINISHED;
        status = PG_SASL_EXCHANGE_SUCCESS;
    }

    /*
     * Don't let extra copies of the bearer token hang around.
     *
     * NOTE(review): the ereport(ERROR) paths after pstrdup() above leave
     * input_copy unscrubbed.
     */
    explicit_bzero(input_copy, inputlen);

    return status;
}
     300              : 
     301              : /*
     302              :  * Convert an arbitrary byte to printable form.  For error messages.
     303              :  *
     304              :  * If it's a printable ASCII character, print it as a single character.
     305              :  * otherwise, print it in hex.
     306              :  *
     307              :  * The returned pointer points to a static buffer.
     308              :  */
     309              : static char *
     310            0 : sanitize_char(char c)
     311              : {
     312              :     static char buf[5];
     313              : 
     314            0 :     if (c >= 0x21 && c <= 0x7E)
     315            0 :         snprintf(buf, sizeof(buf), "'%c'", c);
     316              :     else
     317            0 :         snprintf(buf, sizeof(buf), "0x%02x", (unsigned char) c);
     318            0 :     return buf;
     319              : }
     320              : 
     321              : /*
     322              :  * Performs syntactic validation of a key and value from the initial client
     323              :  * response. (Semantic validation of interesting values must be performed
     324              :  * later.)
     325              :  */
     326              : static void
     327            0 : validate_kvpair(const char *key, const char *val)
     328              : {
     329              :     /*-----
     330              :      * From Sec 3.1:
     331              :      *     key            = 1*(ALPHA)
     332              :      */
     333              :     static const char *key_allowed_set =
     334              :         "abcdefghijklmnopqrstuvwxyz"
     335              :         "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
     336              : 
     337              :     size_t      span;
     338              : 
     339            0 :     if (!key[0])
     340            0 :         ereport(ERROR,
     341              :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     342              :                 errmsg("malformed OAUTHBEARER message"),
     343              :                 errdetail("Message contains an empty key name."));
     344              : 
     345            0 :     span = strspn(key, key_allowed_set);
     346            0 :     if (key[span] != '\0')
     347            0 :         ereport(ERROR,
     348              :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     349              :                 errmsg("malformed OAUTHBEARER message"),
     350              :                 errdetail("Message contains an invalid key name."));
     351              : 
     352              :     /*-----
     353              :      * From Sec 3.1:
     354              :      *     value          = *(VCHAR / SP / HTAB / CR / LF )
     355              :      *
     356              :      * The VCHAR (visible character) class is large; a loop is more
     357              :      * straightforward than strspn().
     358              :      */
     359            0 :     for (; *val; ++val)
     360              :     {
     361            0 :         if (0x21 <= *val && *val <= 0x7E)
     362            0 :             continue;           /* VCHAR */
     363              : 
     364            0 :         switch (*val)
     365              :         {
     366            0 :             case ' ':
     367              :             case '\t':
     368              :             case '\r':
     369              :             case '\n':
     370            0 :                 continue;       /* SP, HTAB, CR, LF */
     371              : 
     372            0 :             default:
     373            0 :                 ereport(ERROR,
     374              :                         errcode(ERRCODE_PROTOCOL_VIOLATION),
     375              :                         errmsg("malformed OAUTHBEARER message"),
     376              :                         errdetail("Message contains an invalid value."));
     377              :         }
     378              :     }
     379            0 : }
     380              : 
     381              : /*
     382              :  * Consumes all kvpairs in an OAUTHBEARER exchange message. If the "auth" key is
     383              :  * found, its value is returned.
     384              :  */
     385              : static char *
     386            0 : parse_kvpairs_for_auth(char **input)
     387              : {
     388            0 :     char       *pos = *input;
     389            0 :     char       *auth = NULL;
     390              : 
     391              :     /*----
     392              :      * The relevant ABNF, from Sec. 3.1:
     393              :      *
     394              :      *     kvsep          = %x01
     395              :      *     key            = 1*(ALPHA)
     396              :      *     value          = *(VCHAR / SP / HTAB / CR / LF )
     397              :      *     kvpair         = key "=" value kvsep
     398              :      *   ;;gs2-header     = See RFC 5801
     399              :      *     client-resp    = (gs2-header kvsep *kvpair kvsep) / kvsep
     400              :      *
     401              :      * By the time we reach this code, the gs2-header and initial kvsep have
     402              :      * already been validated. We start at the beginning of the first kvpair.
     403              :      */
     404              : 
     405            0 :     while (*pos)
     406              :     {
     407              :         char       *end;
     408              :         char       *sep;
     409              :         char       *key;
     410              :         char       *value;
     411              : 
     412              :         /*
     413              :          * Find the end of this kvpair. Note that input is null-terminated by
     414              :          * the SASL code, so the strchr() is bounded.
     415              :          */
     416            0 :         end = strchr(pos, KVSEP);
     417            0 :         if (!end)
     418            0 :             ereport(ERROR,
     419              :                     errcode(ERRCODE_PROTOCOL_VIOLATION),
     420              :                     errmsg("malformed OAUTHBEARER message"),
     421              :                     errdetail("Message contains an unterminated key/value pair."));
     422            0 :         *end = '\0';
     423              : 
     424            0 :         if (pos == end)
     425              :         {
     426              :             /* Empty kvpair, signifying the end of the list. */
     427            0 :             *input = pos + 1;
     428            0 :             return auth;
     429              :         }
     430              : 
     431              :         /*
     432              :          * Find the end of the key name.
     433              :          */
     434            0 :         sep = strchr(pos, '=');
     435            0 :         if (!sep)
     436            0 :             ereport(ERROR,
     437              :                     errcode(ERRCODE_PROTOCOL_VIOLATION),
     438              :                     errmsg("malformed OAUTHBEARER message"),
     439              :                     errdetail("Message contains a key without a value."));
     440            0 :         *sep = '\0';
     441              : 
     442              :         /* Both key and value are now safely terminated. */
     443            0 :         key = pos;
     444            0 :         value = sep + 1;
     445            0 :         validate_kvpair(key, value);
     446              : 
     447            0 :         if (strcmp(key, AUTH_KEY) == 0)
     448              :         {
     449            0 :             if (auth)
     450            0 :                 ereport(ERROR,
     451              :                         errcode(ERRCODE_PROTOCOL_VIOLATION),
     452              :                         errmsg("malformed OAUTHBEARER message"),
     453              :                         errdetail("Message contains multiple auth values."));
     454              : 
     455            0 :             auth = value;
     456              :         }
     457              :         else
     458              :         {
     459              :             /*
     460              :              * The RFC also defines the host and port keys, but they are not
     461              :              * required for OAUTHBEARER and we do not use them. Also, per Sec.
     462              :              * 3.1, any key/value pairs we don't recognize must be ignored.
     463              :              */
     464              :         }
     465              : 
     466              :         /* Move to the next pair. */
     467            0 :         pos = end + 1;
     468              :     }
     469              : 
     470            0 :     ereport(ERROR,
     471              :             errcode(ERRCODE_PROTOCOL_VIOLATION),
     472              :             errmsg("malformed OAUTHBEARER message"),
     473              :             errdetail("Message did not contain a final terminator."));
     474              : 
     475              :     pg_unreachable();
     476              :     return NULL;
     477              : }
     478              : 
     479              : /*
     480              :  * Builds the JSON response for failed authentication (RFC 7628, Sec. 3.2.2).
     481              :  * This contains the required scopes for entry and a pointer to the OAuth/OpenID
     482              :  * discovery document, which the client may use to conduct its OAuth flow.
     483              :  */
     484              : static void
     485            0 : generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen)
     486              : {
     487              :     StringInfoData buf;
     488              :     StringInfoData issuer;
     489              : 
     490              :     /*
     491              :      * The admin needs to set an issuer and scope for OAuth to work. There's
     492              :      * not really a way to hide this from the user, either, because we can't
     493              :      * choose a "default" issuer, so be honest in the failure message. (In
     494              :      * practice such configurations are rejected during HBA parsing.)
     495              :      */
     496            0 :     if (!ctx->issuer || !ctx->scope)
     497            0 :         ereport(FATAL,
     498              :                 errcode(ERRCODE_INTERNAL_ERROR),
     499              :                 errmsg("OAuth is not properly configured for this user"),
     500              :                 errdetail_log("The issuer and scope parameters must be set in pg_hba.conf."));
     501              : 
     502              :     /*
     503              :      * Build a default .well-known URI based on our issuer, unless the HBA has
     504              :      * already provided one.
     505              :      */
     506            0 :     initStringInfo(&issuer);
     507            0 :     appendStringInfoString(&issuer, ctx->issuer);
     508            0 :     if (strstr(ctx->issuer, "/.well-known/") == NULL)
     509            0 :         appendStringInfoString(&issuer, "/.well-known/openid-configuration");
     510              : 
     511            0 :     initStringInfo(&buf);
     512              : 
     513              :     /*
     514              :      * Escaping the string here is belt-and-suspenders defensive programming
     515              :      * since escapable characters aren't valid in either the issuer URI or the
     516              :      * scope list, but the HBA doesn't enforce that yet.
     517              :      */
     518            0 :     appendStringInfoString(&buf, "{ \"status\": \"invalid_token\", ");
     519              : 
     520            0 :     appendStringInfoString(&buf, "\"openid-configuration\": ");
     521            0 :     escape_json(&buf, issuer.data);
     522            0 :     pfree(issuer.data);
     523              : 
     524            0 :     appendStringInfoString(&buf, ", \"scope\": ");
     525            0 :     escape_json(&buf, ctx->scope);
     526              : 
     527            0 :     appendStringInfoString(&buf, " }");
     528              : 
     529            0 :     *output = buf.data;
     530            0 :     *outputlen = buf.len;
     531            0 : }
     532              : 
     533              : /*-----
     534              :  * Validates the provided Authorization header and returns the token from
     535              :  * within it. NULL is returned on validation failure.
     536              :  *
     537              :  * Only Bearer tokens are accepted. The ABNF is defined in RFC 6750, Sec.
     538              :  * 2.1:
     539              :  *
     540              :  *      b64token    = 1*( ALPHA / DIGIT /
     541              :  *                        "-" / "." / "_" / "~" / "+" / "/" ) *"="
     542              :  *      credentials = "Bearer" 1*SP b64token
     543              :  *
     544              :  * The "credentials" construction is what we receive in our auth value.
     545              :  *
     546              :  * Since that spec is subordinate to HTTP (i.e. the HTTP Authorization
     547              :  * header format; RFC 9110 Sec. 11), the "Bearer" scheme string must be
     548              :  * compared case-insensitively. (This is not mentioned in RFC 6750, but the
     549              :  * OAUTHBEARER spec points it out: RFC 7628 Sec. 4.)
     550              :  *
     551              :  * Invalid formats are technically a protocol violation, but we shouldn't
     552              :  * reflect any information about the sensitive Bearer token back to the
     553              :  * client; log at COMMERROR instead.
     554              :  */
     555              : static const char *
     556            0 : validate_token_format(const char *header)
     557              : {
     558              :     size_t      span;
     559              :     const char *token;
     560              :     static const char *const b64token_allowed_set =
     561              :         "abcdefghijklmnopqrstuvwxyz"
     562              :         "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
     563              :         "0123456789-._~+/";
     564              : 
     565              :     /* Missing auth headers should be handled by the caller. */
     566              :     Assert(header);
     567              : 
     568            0 :     if (header[0] == '\0')
     569              :     {
     570              :         /*
     571              :          * A completely empty auth header represents a query for
     572              :          * authentication parameters. The client expects it to fail; there's
     573              :          * no need to make any extra noise in the logs.
     574              :          *
     575              :          * TODO: should we find a way to return STATUS_EOF at the top level,
     576              :          * to suppress the authentication error entirely?
     577              :          */
     578            0 :         return NULL;
     579              :     }
     580              : 
     581            0 :     if (pg_strncasecmp(header, BEARER_SCHEME, strlen(BEARER_SCHEME)))
     582              :     {
     583            0 :         ereport(COMMERROR,
     584              :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     585              :                 errmsg("malformed OAuth bearer token"),
     586              :                 errdetail_log("Client response indicated a non-Bearer authentication scheme."));
     587            0 :         return NULL;
     588              :     }
     589              : 
     590              :     /* Pull the bearer token out of the auth value. */
     591            0 :     token = header + strlen(BEARER_SCHEME);
     592              : 
     593              :     /* Swallow any additional spaces. */
     594            0 :     while (*token == ' ')
     595            0 :         token++;
     596              : 
     597              :     /* Tokens must not be empty. */
     598            0 :     if (!*token)
     599              :     {
     600            0 :         ereport(COMMERROR,
     601              :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     602              :                 errmsg("malformed OAuth bearer token"),
     603              :                 errdetail_log("Bearer token is empty."));
     604            0 :         return NULL;
     605              :     }
     606              : 
     607              :     /*
     608              :      * Make sure the token contains only allowed characters. Tokens may end
     609              :      * with any number of '=' characters.
     610              :      */
     611            0 :     span = strspn(token, b64token_allowed_set);
     612            0 :     while (token[span] == '=')
     613            0 :         span++;
     614              : 
     615            0 :     if (token[span] != '\0')
     616              :     {
     617              :         /*
     618              :          * This error message could be more helpful by printing the
     619              :          * problematic character(s), but that'd be a bit like printing a piece
     620              :          * of someone's password into the logs.
     621              :          */
     622            0 :         ereport(COMMERROR,
     623              :                 errcode(ERRCODE_PROTOCOL_VIOLATION),
     624              :                 errmsg("malformed OAuth bearer token"),
     625              :                 errdetail_log("Bearer token is not in the correct format."));
     626            0 :         return NULL;
     627              :     }
     628              : 
     629            0 :     return token;
     630              : }
     631              : 
     632              : /*
     633              :  * Checks that the "auth" kvpair in the client response contains a syntactically
     634              :  * valid Bearer token, then passes it along to the loaded validator module for
     635              :  * authorization. Returns true if validation succeeds.
     636              :  */
     637              : static bool
     638            0 : validate(Port *port, const char *auth)
     639              : {
     640              :     int         map_status;
     641              :     ValidatorModuleResult *ret;
     642              :     const char *token;
     643              :     bool        status;
     644              : 
     645              :     /* Ensure that we have a correct token to validate */
     646            0 :     if (!(token = validate_token_format(auth)))
     647            0 :         return false;
     648              : 
     649              :     /*
     650              :      * Ensure that we have a validation library loaded, this should always be
     651              :      * the case and an error here is indicative of a bug.
     652              :      */
     653            0 :     if (!ValidatorCallbacks || !ValidatorCallbacks->validate_cb)
     654            0 :         ereport(FATAL,
     655              :                 errcode(ERRCODE_INTERNAL_ERROR),
     656              :                 errmsg("validation of OAuth token requested without a validator loaded"));
     657              : 
     658              :     /* Call the validation function from the validator module */
     659            0 :     ret = palloc0_object(ValidatorModuleResult);
     660            0 :     if (!ValidatorCallbacks->validate_cb(validator_module_state, token,
     661            0 :                                          port->user_name, ret))
     662              :     {
     663            0 :         ereport(WARNING,
     664              :                 errcode(ERRCODE_INTERNAL_ERROR),
     665              :                 errmsg("internal error in OAuth validator module"));
     666            0 :         return false;
     667              :     }
     668              : 
     669              :     /*
     670              :      * Log any authentication results even if the token isn't authorized; it
     671              :      * might be useful for auditing or troubleshooting.
     672              :      */
     673            0 :     if (ret->authn_id)
     674            0 :         set_authn_id(port, ret->authn_id);
     675              : 
     676            0 :     if (!ret->authorized)
     677              :     {
     678            0 :         ereport(LOG,
     679              :                 errmsg("OAuth bearer authentication failed for user \"%s\"",
     680              :                        port->user_name),
     681              :                 errdetail_log("Validator failed to authorize the provided token."));
     682              : 
     683            0 :         status = false;
     684            0 :         goto cleanup;
     685              :     }
     686              : 
     687            0 :     if (port->hba->oauth_skip_usermap)
     688              :     {
     689              :         /*
     690              :          * If the validator is our authorization authority, we're done.
     691              :          * Authentication may or may not have been performed depending on the
     692              :          * validator implementation; all that matters is that the validator
     693              :          * says the user can log in with the target role.
     694              :          */
     695            0 :         status = true;
     696            0 :         goto cleanup;
     697              :     }
     698              : 
     699              :     /* Make sure the validator authenticated the user. */
     700            0 :     if (ret->authn_id == NULL || ret->authn_id[0] == '\0')
     701              :     {
     702            0 :         ereport(LOG,
     703              :                 errmsg("OAuth bearer authentication failed for user \"%s\"",
     704              :                        port->user_name),
     705              :                 errdetail_log("Validator provided no identity."));
     706              : 
     707            0 :         status = false;
     708            0 :         goto cleanup;
     709              :     }
     710              : 
     711              :     /* Finally, check the user map. */
     712            0 :     map_status = check_usermap(port->hba->usermap, port->user_name,
     713              :                                MyClientConnectionInfo.authn_id, false);
     714            0 :     status = (map_status == STATUS_OK);
     715              : 
     716            0 : cleanup:
     717              : 
     718              :     /*
     719              :      * Clear and free the validation result from the validator module once
     720              :      * we're done with it.
     721              :      */
     722            0 :     if (ret->authn_id != NULL)
     723            0 :         pfree(ret->authn_id);
     724            0 :     pfree(ret);
     725              : 
     726            0 :     return status;
     727              : }
     728              : 
     729              : /*
     730              :  * load_validator_library
     731              :  *
     732              :  * Load the configured validator library in order to perform token validation.
     733              :  * There is no built-in fallback since validation is implementation specific. If
     734              :  * no validator library is configured, or if it fails to load, then error out
     735              :  * since token validation won't be possible.
     736              :  */
     737              : static void
     738            0 : load_validator_library(const char *libname)
     739              : {
     740              :     OAuthValidatorModuleInit validator_init;
     741              :     MemoryContextCallback *mcb;
     742              : 
     743              :     /*
     744              :      * The presence, and validity, of libname has already been established by
     745              :      * check_oauth_validator so we don't need to perform more than Assert
     746              :      * level checking here.
     747              :      */
     748              :     Assert(libname && *libname);
     749              : 
     750            0 :     validator_init = (OAuthValidatorModuleInit)
     751            0 :         load_external_function(libname, "_PG_oauth_validator_module_init",
     752              :                                false, NULL);
     753              : 
     754              :     /*
     755              :      * The validator init function is required since it will set the callbacks
     756              :      * for the validator library.
     757              :      */
     758            0 :     if (validator_init == NULL)
     759            0 :         ereport(ERROR,
     760              :                 errmsg("%s module \"%s\" must define the symbol %s",
     761              :                        "OAuth validator", libname, "_PG_oauth_validator_module_init"));
     762              : 
     763            0 :     ValidatorCallbacks = (*validator_init) ();
     764              :     Assert(ValidatorCallbacks);
     765              : 
     766              :     /*
     767              :      * Check the magic number, to protect against break-glass scenarios where
     768              :      * the ABI must change within a major version. load_external_function()
     769              :      * already checks for compatibility across major versions.
     770              :      */
     771            0 :     if (ValidatorCallbacks->magic != PG_OAUTH_VALIDATOR_MAGIC)
     772            0 :         ereport(ERROR,
     773              :                 errmsg("%s module \"%s\": magic number mismatch",
     774              :                        "OAuth validator", libname),
     775              :                 errdetail("Server has magic number 0x%08X, module has 0x%08X.",
     776              :                           PG_OAUTH_VALIDATOR_MAGIC, ValidatorCallbacks->magic));
     777              : 
     778              :     /*
     779              :      * Make sure all required callbacks are present in the ValidatorCallbacks
     780              :      * structure. Right now only the validation callback is required.
     781              :      */
     782            0 :     if (ValidatorCallbacks->validate_cb == NULL)
     783            0 :         ereport(ERROR,
     784              :                 errmsg("%s module \"%s\" must provide a %s callback",
     785              :                        "OAuth validator", libname, "validate_cb"));
     786              : 
     787              :     /* Allocate memory for validator library private state data */
     788            0 :     validator_module_state = palloc0_object(ValidatorModuleState);
     789            0 :     validator_module_state->sversion = PG_VERSION_NUM;
     790              : 
     791            0 :     if (ValidatorCallbacks->startup_cb != NULL)
     792            0 :         ValidatorCallbacks->startup_cb(validator_module_state);
     793              : 
     794              :     /* Shut down the library before cleaning up its state. */
     795            0 :     mcb = palloc0_object(MemoryContextCallback);
     796            0 :     mcb->func = shutdown_validator_library;
     797              : 
     798            0 :     MemoryContextRegisterResetCallback(CurrentMemoryContext, mcb);
     799            0 : }
     800              : 
     801              : /*
     802              :  * Call the validator module's shutdown callback, if one is provided. This is
     803              :  * invoked during memory context reset.
     804              :  */
     805              : static void
     806            0 : shutdown_validator_library(void *arg)
     807              : {
     808            0 :     if (ValidatorCallbacks->shutdown_cb != NULL)
     809            0 :         ValidatorCallbacks->shutdown_cb(validator_module_state);
     810            0 : }
     811              : 
     812              : /*
     813              :  * Ensure an OAuth validator named in the HBA is permitted by the configuration.
     814              :  *
     815              :  * If the validator is currently unset and exactly one library is declared in
     816              :  * oauth_validator_libraries, then that library will be used as the validator.
     817              :  * Otherwise the name must be present in the list of oauth_validator_libraries.
     818              :  */
     819              : bool
     820            0 : check_oauth_validator(HbaLine *hbaline, int elevel, char **err_msg)
     821              : {
     822            0 :     int         line_num = hbaline->linenumber;
     823            0 :     const char *file_name = hbaline->sourcefile;
     824              :     char       *rawstring;
     825            0 :     List       *elemlist = NIL;
     826              : 
     827            0 :     *err_msg = NULL;
     828              : 
     829            0 :     if (oauth_validator_libraries_string[0] == '\0')
     830              :     {
     831            0 :         ereport(elevel,
     832              :                 errcode(ERRCODE_CONFIG_FILE_ERROR),
     833              :                 errmsg("oauth_validator_libraries must be set for authentication method %s",
     834              :                        "oauth"),
     835              :                 errcontext("line %d of configuration file \"%s\"",
     836              :                            line_num, file_name));
     837            0 :         *err_msg = psprintf("oauth_validator_libraries must be set for authentication method %s",
     838              :                             "oauth");
     839            0 :         return false;
     840              :     }
     841              : 
     842              :     /* SplitDirectoriesString needs a modifiable copy */
     843            0 :     rawstring = pstrdup(oauth_validator_libraries_string);
     844              : 
     845            0 :     if (!SplitDirectoriesString(rawstring, ',', &elemlist))
     846              :     {
     847              :         /* syntax error in list */
     848            0 :         ereport(elevel,
     849              :                 errcode(ERRCODE_CONFIG_FILE_ERROR),
     850              :                 errmsg("invalid list syntax in parameter \"%s\"",
     851              :                        "oauth_validator_libraries"));
     852            0 :         *err_msg = psprintf("invalid list syntax in parameter \"%s\"",
     853              :                             "oauth_validator_libraries");
     854            0 :         goto done;
     855              :     }
     856              : 
     857            0 :     if (!hbaline->oauth_validator)
     858              :     {
     859            0 :         if (elemlist->length == 1)
     860              :         {
     861            0 :             hbaline->oauth_validator = pstrdup(linitial(elemlist));
     862            0 :             goto done;
     863              :         }
     864              : 
     865            0 :         ereport(elevel,
     866              :                 errcode(ERRCODE_CONFIG_FILE_ERROR),
     867              :                 errmsg("authentication method \"oauth\" requires argument \"validator\" to be set when oauth_validator_libraries contains multiple options"),
     868              :                 errcontext("line %d of configuration file \"%s\"",
     869              :                            line_num, file_name));
     870            0 :         *err_msg = "authentication method \"oauth\" requires argument \"validator\" to be set when oauth_validator_libraries contains multiple options";
     871            0 :         goto done;
     872              :     }
     873              : 
     874            0 :     foreach_ptr(char, allowed, elemlist)
     875              :     {
     876            0 :         if (strcmp(allowed, hbaline->oauth_validator) == 0)
     877            0 :             goto done;
     878              :     }
     879              : 
     880            0 :     ereport(elevel,
     881              :             errcode(ERRCODE_INVALID_PARAMETER_VALUE),
     882              :             errmsg("validator \"%s\" is not permitted by %s",
     883              :                    hbaline->oauth_validator, "oauth_validator_libraries"),
     884              :             errcontext("line %d of configuration file \"%s\"",
     885              :                        line_num, file_name));
     886            0 :     *err_msg = psprintf("validator \"%s\" is not permitted by %s",
     887              :                         hbaline->oauth_validator, "oauth_validator_libraries");
     888              : 
     889            0 : done:
     890            0 :     list_free_deep(elemlist);
     891            0 :     pfree(rawstring);
     892              : 
     893            0 :     return (*err_msg == NULL);
     894              : }
        

Generated by: LCOV version 2.0-1