Line data Source code
1 : /*-------------------------------------------------------------------------
2 : *
3 : * auth-oauth.c
4 : * Server-side implementation of the SASL OAUTHBEARER mechanism.
5 : *
6 : * See the following RFC for more details:
7 : * - RFC 7628: https://datatracker.ietf.org/doc/html/rfc7628
8 : *
9 : * Portions Copyright (c) 1996-2026, PostgreSQL Global Development Group
10 : * Portions Copyright (c) 1994, Regents of the University of California
11 : *
12 : * src/backend/libpq/auth-oauth.c
13 : *
14 : *-------------------------------------------------------------------------
15 : */
16 : #include "postgres.h"
17 :
18 : #include <unistd.h>
19 : #include <fcntl.h>
20 :
21 : #include "common/oauth-common.h"
22 : #include "fmgr.h"
23 : #include "lib/stringinfo.h"
24 : #include "libpq/auth.h"
25 : #include "libpq/hba.h"
26 : #include "libpq/oauth.h"
27 : #include "libpq/sasl.h"
28 : #include "storage/fd.h"
29 : #include "storage/ipc.h"
30 : #include "utils/json.h"
31 : #include "utils/varlena.h"
32 :
33 : /* GUC */
34 : char *oauth_validator_libraries_string = NULL;
35 :
36 : static void oauth_get_mechanisms(Port *port, StringInfo buf);
37 : static void *oauth_init(Port *port, const char *selected_mech, const char *shadow_pass);
38 : static int oauth_exchange(void *opaq, const char *input, int inputlen,
39 : char **output, int *outputlen, const char **logdetail);
40 :
41 : static void load_validator_library(const char *libname);
42 : static void shutdown_validator_library(void *arg);
43 :
44 : static ValidatorModuleState *validator_module_state;
45 : static const OAuthValidatorCallbacks *ValidatorCallbacks;
46 :
47 : /* Mechanism declaration */
48 : const pg_be_sasl_mech pg_be_oauth_mech = {
49 : .get_mechanisms = oauth_get_mechanisms,
50 : .init = oauth_init,
51 : .exchange = oauth_exchange,
52 :
53 : .max_message_length = PG_MAX_AUTH_TOKEN_LENGTH,
54 : };
55 :
/*
 * States of the oauth_exchange() machine. The numbering starts at zero, so a
 * zeroed context begins in OAUTH_STATE_INIT.
 */
enum oauth_state
{
	OAUTH_STATE_INIT,			/* waiting for the client's initial response */
	OAUTH_STATE_ERROR,			/* token rejected; waiting for final kvsep */
	OAUTH_STATE_ERROR_DISCOVERY,	/* discovery request; waiting for final
									 * kvsep */
	OAUTH_STATE_FINISHED,		/* exchange is over */
};
64 :
65 : /* Mechanism callback state. */
66 : struct oauth_ctx
67 : {
68 : enum oauth_state state;
69 : Port *port;
70 : const char *issuer;
71 : const char *scope;
72 : };
73 :
74 : static char *sanitize_char(char c);
75 : static char *parse_kvpairs_for_auth(char **input);
76 : static void generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen);
77 : static bool validate(Port *port, const char *auth, const char **logdetail);
78 :
79 : /* Constants seen in an OAUTHBEARER client initial response. */
80 : #define KVSEP 0x01 /* separator byte for key/value pairs */
81 : #define AUTH_KEY "auth" /* key containing the Authorization header */
82 : #define BEARER_SCHEME "Bearer " /* required header scheme (case-insensitive!) */
83 :
84 : /*
85 : * Retrieves the OAUTHBEARER mechanism list (currently a single item).
86 : *
87 : * For a full description of the API, see libpq/sasl.h.
88 : */
89 : static void
90 0 : oauth_get_mechanisms(Port *port, StringInfo buf)
91 : {
92 : /* Only OAUTHBEARER is supported. */
93 0 : appendStringInfoString(buf, OAUTHBEARER_NAME);
94 0 : appendStringInfoChar(buf, '\0');
95 0 : }
96 :
97 : /*
98 : * Initializes mechanism state and loads the configured validator module.
99 : *
100 : * For a full description of the API, see libpq/sasl.h.
101 : */
102 : static void *
103 0 : oauth_init(Port *port, const char *selected_mech, const char *shadow_pass)
104 : {
105 : struct oauth_ctx *ctx;
106 :
107 0 : if (strcmp(selected_mech, OAUTHBEARER_NAME) != 0)
108 0 : ereport(ERROR,
109 : errcode(ERRCODE_PROTOCOL_VIOLATION),
110 : errmsg("client selected an invalid SASL authentication mechanism"));
111 :
112 0 : ctx = palloc0_object(struct oauth_ctx);
113 :
114 0 : ctx->state = OAUTH_STATE_INIT;
115 0 : ctx->port = port;
116 :
117 : Assert(port->hba);
118 0 : ctx->issuer = port->hba->oauth_issuer;
119 0 : ctx->scope = port->hba->oauth_scope;
120 :
121 0 : load_validator_library(port->hba->oauth_validator);
122 :
123 0 : return ctx;
124 : }
125 :
126 : /*
127 : * Implements the OAUTHBEARER SASL exchange (RFC 7628, Sec. 3.2). This pulls
128 : * apart the client initial response and validates the Bearer token. It also
129 : * handles the dummy error response for a failed handshake, as described in
130 : * Sec. 3.2.3.
131 : *
132 : * For a full description of the API, see libpq/sasl.h.
133 : */
134 : static int
135 0 : oauth_exchange(void *opaq, const char *input, int inputlen,
136 : char **output, int *outputlen, const char **logdetail)
137 : {
138 : char *input_copy;
139 : char *p;
140 : char cbind_flag;
141 : char *auth;
142 : int status;
143 :
144 0 : struct oauth_ctx *ctx = opaq;
145 :
146 0 : *output = NULL;
147 0 : *outputlen = -1;
148 :
149 : /*
150 : * If the client didn't include an "Initial Client Response" in the
151 : * SASLInitialResponse message, send an empty challenge, to which the
152 : * client will respond with the same data that usually comes in the
153 : * Initial Client Response.
154 : */
155 0 : if (input == NULL)
156 : {
157 : Assert(ctx->state == OAUTH_STATE_INIT);
158 :
159 0 : *output = pstrdup("");
160 0 : *outputlen = 0;
161 0 : return PG_SASL_EXCHANGE_CONTINUE;
162 : }
163 :
164 : /*
165 : * Check that the input length agrees with the string length of the input.
166 : */
167 0 : if (inputlen == 0)
168 0 : ereport(ERROR,
169 : errcode(ERRCODE_PROTOCOL_VIOLATION),
170 : errmsg("malformed OAUTHBEARER message"),
171 : errdetail("The message is empty."));
172 0 : if (inputlen != strlen(input))
173 0 : ereport(ERROR,
174 : errcode(ERRCODE_PROTOCOL_VIOLATION),
175 : errmsg("malformed OAUTHBEARER message"),
176 : errdetail("Message length does not match input length."));
177 :
178 0 : switch (ctx->state)
179 : {
180 0 : case OAUTH_STATE_INIT:
181 : /* Handle this case below. */
182 0 : break;
183 :
184 0 : case OAUTH_STATE_ERROR:
185 : case OAUTH_STATE_ERROR_DISCOVERY:
186 :
187 : /*
188 : * Only one response is valid for the client during authentication
189 : * failure: a single kvsep.
190 : */
191 0 : if (inputlen != 1 || *input != KVSEP)
192 0 : ereport(ERROR,
193 : errcode(ERRCODE_PROTOCOL_VIOLATION),
194 : errmsg("malformed OAUTHBEARER message"),
195 : errdetail("Client did not send a kvsep response."));
196 :
197 : /*
198 : * The (failed) handshake is now complete. Don't report discovery
199 : * requests in the server log unless the log level is high enough.
200 : */
201 0 : if (ctx->state == OAUTH_STATE_ERROR_DISCOVERY)
202 : {
203 0 : ereport(DEBUG1, errmsg("OAuth issuer discovery requested"));
204 :
205 0 : ctx->state = OAUTH_STATE_FINISHED;
206 0 : return PG_SASL_EXCHANGE_ABANDONED;
207 : }
208 :
209 : /* We're not in discovery, so this is just a normal auth failure. */
210 0 : ctx->state = OAUTH_STATE_FINISHED;
211 0 : return PG_SASL_EXCHANGE_FAILURE;
212 :
213 0 : default:
214 0 : elog(ERROR, "invalid OAUTHBEARER exchange state");
215 : return PG_SASL_EXCHANGE_FAILURE;
216 : }
217 :
218 : /* Handle the client's initial message. */
219 0 : p = input_copy = pstrdup(input);
220 :
221 : /*
222 : * OAUTHBEARER does not currently define a channel binding (so there is no
223 : * OAUTHBEARER-PLUS, and we do not accept a 'p' specifier). We accept a
224 : * 'y' specifier purely for the remote chance that a future specification
225 : * could define one; then future clients can still interoperate with this
226 : * server implementation. 'n' is the expected case.
227 : */
228 0 : cbind_flag = *p;
229 0 : switch (cbind_flag)
230 : {
231 0 : case 'p':
232 0 : ereport(ERROR,
233 : errcode(ERRCODE_PROTOCOL_VIOLATION),
234 : errmsg("malformed OAUTHBEARER message"),
235 : errdetail("The server does not support channel binding for OAuth, but the client message includes channel binding data."));
236 : break;
237 :
238 0 : case 'y': /* fall through */
239 : case 'n':
240 0 : p++;
241 0 : if (*p != ',')
242 0 : ereport(ERROR,
243 : errcode(ERRCODE_PROTOCOL_VIOLATION),
244 : errmsg("malformed OAUTHBEARER message"),
245 : errdetail("Comma expected, but found character \"%s\".",
246 : sanitize_char(*p)));
247 0 : p++;
248 0 : break;
249 :
250 0 : default:
251 0 : ereport(ERROR,
252 : errcode(ERRCODE_PROTOCOL_VIOLATION),
253 : errmsg("malformed OAUTHBEARER message"),
254 : errdetail("Unexpected channel-binding flag \"%s\".",
255 : sanitize_char(cbind_flag)));
256 : }
257 :
258 : /*
259 : * Forbid optional authzid (authorization identity). We don't support it.
260 : */
261 0 : if (*p == 'a')
262 0 : ereport(ERROR,
263 : errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
264 : errmsg("client uses authorization identity, but it is not supported"));
265 0 : if (*p != ',')
266 0 : ereport(ERROR,
267 : errcode(ERRCODE_PROTOCOL_VIOLATION),
268 : errmsg("malformed OAUTHBEARER message"),
269 : errdetail("Unexpected attribute \"%s\" in client-first-message.",
270 : sanitize_char(*p)));
271 0 : p++;
272 :
273 : /* All remaining fields are separated by the RFC's kvsep (\x01). */
274 0 : if (*p != KVSEP)
275 0 : ereport(ERROR,
276 : errcode(ERRCODE_PROTOCOL_VIOLATION),
277 : errmsg("malformed OAUTHBEARER message"),
278 : errdetail("Key-value separator expected, but found character \"%s\".",
279 : sanitize_char(*p)));
280 0 : p++;
281 :
282 0 : auth = parse_kvpairs_for_auth(&p);
283 0 : if (!auth)
284 0 : ereport(ERROR,
285 : errcode(ERRCODE_PROTOCOL_VIOLATION),
286 : errmsg("malformed OAUTHBEARER message"),
287 : errdetail("Message does not contain an auth value."));
288 :
289 : /* We should be at the end of our message. */
290 0 : if (*p)
291 0 : ereport(ERROR,
292 : errcode(ERRCODE_PROTOCOL_VIOLATION),
293 : errmsg("malformed OAUTHBEARER message"),
294 : errdetail("Message contains additional data after the final terminator."));
295 :
296 0 : if (auth[0] == '\0')
297 : {
298 : /*
299 : * An empty auth value represents a discovery request; the client
300 : * expects it to fail. Skip validation entirely and move directly to
301 : * the error response.
302 : */
303 0 : generate_error_response(ctx, output, outputlen);
304 :
305 0 : ctx->state = OAUTH_STATE_ERROR_DISCOVERY;
306 0 : status = PG_SASL_EXCHANGE_CONTINUE;
307 : }
308 0 : else if (!validate(ctx->port, auth, logdetail))
309 : {
310 0 : generate_error_response(ctx, output, outputlen);
311 :
312 0 : ctx->state = OAUTH_STATE_ERROR;
313 0 : status = PG_SASL_EXCHANGE_CONTINUE;
314 : }
315 : else
316 : {
317 0 : ctx->state = OAUTH_STATE_FINISHED;
318 0 : status = PG_SASL_EXCHANGE_SUCCESS;
319 : }
320 :
321 : /* Don't let extra copies of the bearer token hang around. */
322 0 : explicit_bzero(input_copy, inputlen);
323 :
324 0 : return status;
325 : }
326 :
/*
 * Renders a single byte in a form that is safe to put in an error message.
 *
 * Bytes in the range 0x21..0x7E come back as the character wrapped in single
 * quotes; every other byte comes back as two hex digits with a 0x prefix.
 *
 * The result lives in a static buffer, so it is overwritten by the next call.
 */
static char *
sanitize_char(char c)
{
	static char buf[5];

	if (c < 0x21 || c > 0x7E)
	{
		/* Not a visible ASCII character: show its value instead. */
		snprintf(buf, sizeof(buf), "0x%02x", (unsigned char) c);
		return buf;
	}

	snprintf(buf, sizeof(buf), "'%c'", c);
	return buf;
}
346 :
347 : /*
348 : * Performs syntactic validation of a key and value from the initial client
349 : * response. (Semantic validation of interesting values must be performed
350 : * later.)
351 : */
352 : static void
353 0 : validate_kvpair(const char *key, const char *val)
354 : {
355 : /*-----
356 : * From Sec 3.1:
357 : * key = 1*(ALPHA)
358 : */
359 : static const char *key_allowed_set =
360 : "abcdefghijklmnopqrstuvwxyz"
361 : "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
362 :
363 : size_t span;
364 :
365 0 : if (!key[0])
366 0 : ereport(ERROR,
367 : errcode(ERRCODE_PROTOCOL_VIOLATION),
368 : errmsg("malformed OAUTHBEARER message"),
369 : errdetail("Message contains an empty key name."));
370 :
371 0 : span = strspn(key, key_allowed_set);
372 0 : if (key[span] != '\0')
373 0 : ereport(ERROR,
374 : errcode(ERRCODE_PROTOCOL_VIOLATION),
375 : errmsg("malformed OAUTHBEARER message"),
376 : errdetail("Message contains an invalid key name."));
377 :
378 : /*-----
379 : * From Sec 3.1:
380 : * value = *(VCHAR / SP / HTAB / CR / LF )
381 : *
382 : * The VCHAR (visible character) class is large; a loop is more
383 : * straightforward than strspn().
384 : */
385 0 : for (; *val; ++val)
386 : {
387 0 : if (0x21 <= *val && *val <= 0x7E)
388 0 : continue; /* VCHAR */
389 :
390 0 : switch (*val)
391 : {
392 0 : case ' ':
393 : case '\t':
394 : case '\r':
395 : case '\n':
396 0 : continue; /* SP, HTAB, CR, LF */
397 :
398 0 : default:
399 0 : ereport(ERROR,
400 : errcode(ERRCODE_PROTOCOL_VIOLATION),
401 : errmsg("malformed OAUTHBEARER message"),
402 : errdetail("Message contains an invalid value."));
403 : }
404 : }
405 0 : }
406 :
407 : /*
408 : * Consumes all kvpairs in an OAUTHBEARER exchange message. If the "auth" key is
409 : * found, its value is returned.
410 : */
411 : static char *
412 0 : parse_kvpairs_for_auth(char **input)
413 : {
414 0 : char *pos = *input;
415 0 : char *auth = NULL;
416 :
417 : /*----
418 : * The relevant ABNF, from Sec. 3.1:
419 : *
420 : * kvsep = %x01
421 : * key = 1*(ALPHA)
422 : * value = *(VCHAR / SP / HTAB / CR / LF )
423 : * kvpair = key "=" value kvsep
424 : * ;;gs2-header = See RFC 5801
425 : * client-resp = (gs2-header kvsep *kvpair kvsep) / kvsep
426 : *
427 : * By the time we reach this code, the gs2-header and initial kvsep have
428 : * already been validated. We start at the beginning of the first kvpair.
429 : */
430 :
431 0 : while (*pos)
432 : {
433 : char *end;
434 : char *sep;
435 : char *key;
436 : char *value;
437 :
438 : /*
439 : * Find the end of this kvpair. Note that input is null-terminated by
440 : * the SASL code, so the strchr() is bounded.
441 : */
442 0 : end = strchr(pos, KVSEP);
443 0 : if (!end)
444 0 : ereport(ERROR,
445 : errcode(ERRCODE_PROTOCOL_VIOLATION),
446 : errmsg("malformed OAUTHBEARER message"),
447 : errdetail("Message contains an unterminated key/value pair."));
448 0 : *end = '\0';
449 :
450 0 : if (pos == end)
451 : {
452 : /* Empty kvpair, signifying the end of the list. */
453 0 : *input = pos + 1;
454 0 : return auth;
455 : }
456 :
457 : /*
458 : * Find the end of the key name.
459 : */
460 0 : sep = strchr(pos, '=');
461 0 : if (!sep)
462 0 : ereport(ERROR,
463 : errcode(ERRCODE_PROTOCOL_VIOLATION),
464 : errmsg("malformed OAUTHBEARER message"),
465 : errdetail("Message contains a key without a value."));
466 0 : *sep = '\0';
467 :
468 : /* Both key and value are now safely terminated. */
469 0 : key = pos;
470 0 : value = sep + 1;
471 0 : validate_kvpair(key, value);
472 :
473 0 : if (strcmp(key, AUTH_KEY) == 0)
474 : {
475 0 : if (auth)
476 0 : ereport(ERROR,
477 : errcode(ERRCODE_PROTOCOL_VIOLATION),
478 : errmsg("malformed OAUTHBEARER message"),
479 : errdetail("Message contains multiple auth values."));
480 :
481 0 : auth = value;
482 : }
483 : else
484 : {
485 : /*
486 : * The RFC also defines the host and port keys, but they are not
487 : * required for OAUTHBEARER and we do not use them. Also, per Sec.
488 : * 3.1, any key/value pairs we don't recognize must be ignored.
489 : */
490 : }
491 :
492 : /* Move to the next pair. */
493 0 : pos = end + 1;
494 : }
495 :
496 0 : ereport(ERROR,
497 : errcode(ERRCODE_PROTOCOL_VIOLATION),
498 : errmsg("malformed OAUTHBEARER message"),
499 : errdetail("Message did not contain a final terminator."));
500 :
501 : pg_unreachable();
502 : return NULL;
503 : }
504 :
505 : /*
506 : * Builds the JSON response for failed authentication (RFC 7628, Sec. 3.2.2).
507 : * This contains the required scopes for entry and a pointer to the OAuth/OpenID
508 : * discovery document, which the client may use to conduct its OAuth flow.
509 : */
510 : static void
511 0 : generate_error_response(struct oauth_ctx *ctx, char **output, int *outputlen)
512 : {
513 : StringInfoData buf;
514 : StringInfoData issuer;
515 :
516 : /*
517 : * The admin needs to set an issuer and scope for OAuth to work. There's
518 : * not really a way to hide this from the user, either, because we can't
519 : * choose a "default" issuer, so be honest in the failure message. (In
520 : * practice such configurations are rejected during HBA parsing.)
521 : */
522 0 : if (!ctx->issuer || !ctx->scope)
523 0 : ereport(FATAL,
524 : errcode(ERRCODE_INTERNAL_ERROR),
525 : errmsg("OAuth is not properly configured for this user"),
526 : errdetail_log("The issuer and scope parameters must be set in pg_hba.conf."));
527 :
528 : /*
529 : * Build a default .well-known URI based on our issuer, unless the HBA has
530 : * already provided one.
531 : */
532 0 : initStringInfo(&issuer);
533 0 : appendStringInfoString(&issuer, ctx->issuer);
534 0 : if (strstr(ctx->issuer, "/.well-known/") == NULL)
535 0 : appendStringInfoString(&issuer, "/.well-known/openid-configuration");
536 :
537 0 : initStringInfo(&buf);
538 :
539 : /*
540 : * Escaping the string here is belt-and-suspenders defensive programming
541 : * since escapable characters aren't valid in either the issuer URI or the
542 : * scope list, but the HBA doesn't enforce that yet.
543 : */
544 0 : appendStringInfoString(&buf, "{ \"status\": \"invalid_token\", ");
545 :
546 0 : appendStringInfoString(&buf, "\"openid-configuration\": ");
547 0 : escape_json(&buf, issuer.data);
548 0 : pfree(issuer.data);
549 :
550 0 : appendStringInfoString(&buf, ", \"scope\": ");
551 0 : escape_json(&buf, ctx->scope);
552 :
553 0 : appendStringInfoString(&buf, " }");
554 :
555 0 : *output = buf.data;
556 0 : *outputlen = buf.len;
557 0 : }
558 :
559 : /*-----
560 : * Validates the provided Authorization header and returns the token from
561 : * within it. NULL is returned on validation failure.
562 : *
563 : * Only Bearer tokens are accepted. The ABNF is defined in RFC 6750, Sec.
564 : * 2.1:
565 : *
566 : * b64token = 1*( ALPHA / DIGIT /
567 : * "-" / "." / "_" / "~" / "+" / "/" ) *"="
568 : * credentials = "Bearer" 1*SP b64token
569 : *
570 : * The "credentials" construction is what we receive in our auth value.
571 : *
572 : * Since that spec is subordinate to HTTP (i.e. the HTTP Authorization
573 : * header format; RFC 9110 Sec. 11), the "Bearer" scheme string must be
574 : * compared case-insensitively. (This is not mentioned in RFC 6750, but the
575 : * OAUTHBEARER spec points it out: RFC 7628 Sec. 4.)
576 : *
577 : * Invalid formats are technically a protocol violation, but we shouldn't
578 : * reflect any information about the sensitive Bearer token back to the
579 : * client; log at COMMERROR instead.
580 : */
581 : static const char *
582 0 : validate_token_format(const char *header)
583 : {
584 : size_t span;
585 : const char *token;
586 : static const char *const b64token_allowed_set =
587 : "abcdefghijklmnopqrstuvwxyz"
588 : "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
589 : "0123456789-._~+/";
590 :
591 : /* Missing auth headers should be handled by the caller. */
592 : Assert(header);
593 : /* Empty auth (discovery) should be handled before calling validate(). */
594 : Assert(header[0] != '\0');
595 :
596 0 : if (pg_strncasecmp(header, BEARER_SCHEME, strlen(BEARER_SCHEME)))
597 : {
598 0 : ereport(COMMERROR,
599 : errcode(ERRCODE_PROTOCOL_VIOLATION),
600 : errmsg("malformed OAuth bearer token"),
601 : errdetail_log("Client response indicated a non-Bearer authentication scheme."));
602 0 : return NULL;
603 : }
604 :
605 : /* Pull the bearer token out of the auth value. */
606 0 : token = header + strlen(BEARER_SCHEME);
607 :
608 : /* Swallow any additional spaces. */
609 0 : while (*token == ' ')
610 0 : token++;
611 :
612 : /* Tokens must not be empty. */
613 0 : if (!*token)
614 : {
615 0 : ereport(COMMERROR,
616 : errcode(ERRCODE_PROTOCOL_VIOLATION),
617 : errmsg("malformed OAuth bearer token"),
618 : errdetail_log("Bearer token is empty."));
619 0 : return NULL;
620 : }
621 :
622 : /*
623 : * Make sure the token contains only allowed characters. Tokens may end
624 : * with any number of '=' characters.
625 : */
626 0 : span = strspn(token, b64token_allowed_set);
627 0 : while (token[span] == '=')
628 0 : span++;
629 :
630 0 : if (token[span] != '\0')
631 : {
632 : /*
633 : * This error message could be more helpful by printing the
634 : * problematic character(s), but that'd be a bit like printing a piece
635 : * of someone's password into the logs.
636 : */
637 0 : ereport(COMMERROR,
638 : errcode(ERRCODE_PROTOCOL_VIOLATION),
639 : errmsg("malformed OAuth bearer token"),
640 : errdetail_log("Bearer token is not in the correct format."));
641 0 : return NULL;
642 : }
643 :
644 0 : return token;
645 : }
646 :
647 : /*
648 : * Checks that the "auth" kvpair in the client response contains a syntactically
649 : * valid Bearer token, then passes it along to the loaded validator module for
650 : * authorization. Returns true if validation succeeds.
651 : */
652 : static bool
653 0 : validate(Port *port, const char *auth, const char **logdetail)
654 : {
655 : int map_status;
656 : ValidatorModuleResult *ret;
657 : const char *token;
658 : bool status;
659 :
660 : /* Ensure that we have a correct token to validate */
661 0 : if (!(token = validate_token_format(auth)))
662 0 : return false;
663 :
664 : /*
665 : * Ensure that we have a validation library loaded, this should always be
666 : * the case and an error here is indicative of a bug.
667 : */
668 0 : if (!ValidatorCallbacks || !ValidatorCallbacks->validate_cb)
669 0 : ereport(FATAL,
670 : errcode(ERRCODE_INTERNAL_ERROR),
671 : errmsg("validation of OAuth token requested without a validator loaded"));
672 :
673 : /* Call the validation function from the validator module */
674 0 : ret = palloc0_object(ValidatorModuleResult);
675 0 : if (!ValidatorCallbacks->validate_cb(validator_module_state, token,
676 0 : port->user_name, ret))
677 : {
678 0 : ereport(WARNING,
679 : errcode(ERRCODE_INTERNAL_ERROR),
680 : errmsg("internal error in OAuth validator module"),
681 : ret->error_detail ? errdetail_log("%s", ret->error_detail) : 0);
682 :
683 0 : *logdetail = ret->error_detail;
684 0 : return false;
685 : }
686 :
687 : /*
688 : * Log any authentication results even if the token isn't authorized; it
689 : * might be useful for auditing or troubleshooting.
690 : */
691 0 : if (ret->authn_id)
692 0 : set_authn_id(port, ret->authn_id);
693 :
694 0 : if (!ret->authorized)
695 : {
696 0 : if (ret->error_detail)
697 0 : *logdetail = ret->error_detail;
698 : else
699 0 : *logdetail = _("Validator failed to authorize the provided token.");
700 :
701 0 : status = false;
702 0 : goto cleanup;
703 : }
704 :
705 0 : if (port->hba->oauth_skip_usermap)
706 : {
707 : /*
708 : * If the validator is our authorization authority, we're done.
709 : * Authentication may or may not have been performed depending on the
710 : * validator implementation; all that matters is that the validator
711 : * says the user can log in with the target role.
712 : */
713 0 : status = true;
714 0 : goto cleanup;
715 : }
716 :
717 : /* Make sure the validator authenticated the user. */
718 0 : if (ret->authn_id == NULL || ret->authn_id[0] == '\0')
719 : {
720 0 : *logdetail = _("Validator provided no identity.");
721 :
722 0 : status = false;
723 0 : goto cleanup;
724 : }
725 :
726 : /* Finally, check the user map. */
727 0 : map_status = check_usermap(port->hba->usermap, port->user_name,
728 : MyClientConnectionInfo.authn_id, false);
729 0 : status = (map_status == STATUS_OK);
730 :
731 0 : cleanup:
732 :
733 : /*
734 : * Clear and free the validation result from the validator module once
735 : * we're done with it.
736 : */
737 0 : if (ret->authn_id != NULL)
738 0 : pfree(ret->authn_id);
739 0 : pfree(ret);
740 :
741 0 : return status;
742 : }
743 :
744 : /*
745 : * load_validator_library
746 : *
747 : * Load the configured validator library in order to perform token validation.
748 : * There is no built-in fallback since validation is implementation specific. If
749 : * no validator library is configured, or if it fails to load, then error out
750 : * since token validation won't be possible.
751 : */
752 : static void
753 0 : load_validator_library(const char *libname)
754 : {
755 : OAuthValidatorModuleInit validator_init;
756 : MemoryContextCallback *mcb;
757 :
758 : /*
759 : * The presence, and validity, of libname has already been established by
760 : * check_oauth_validator so we don't need to perform more than Assert
761 : * level checking here.
762 : */
763 : Assert(libname && *libname);
764 :
765 0 : validator_init = (OAuthValidatorModuleInit)
766 0 : load_external_function(libname, "_PG_oauth_validator_module_init",
767 : false, NULL);
768 :
769 : /*
770 : * The validator init function is required since it will set the callbacks
771 : * for the validator library.
772 : */
773 0 : if (validator_init == NULL)
774 0 : ereport(ERROR,
775 : errmsg("%s module \"%s\" must define the symbol %s",
776 : "OAuth validator", libname, "_PG_oauth_validator_module_init"));
777 :
778 0 : ValidatorCallbacks = (*validator_init) ();
779 : Assert(ValidatorCallbacks);
780 :
781 : /*
782 : * Check the magic number, to protect against break-glass scenarios where
783 : * the ABI must change within a major version. load_external_function()
784 : * already checks for compatibility across major versions.
785 : */
786 0 : if (ValidatorCallbacks->magic != PG_OAUTH_VALIDATOR_MAGIC)
787 0 : ereport(ERROR,
788 : errmsg("%s module \"%s\": magic number mismatch",
789 : "OAuth validator", libname),
790 : errdetail("Server has magic number 0x%08X, module has 0x%08X.",
791 : PG_OAUTH_VALIDATOR_MAGIC, ValidatorCallbacks->magic));
792 :
793 : /*
794 : * Make sure all required callbacks are present in the ValidatorCallbacks
795 : * structure. Right now only the validation callback is required.
796 : */
797 0 : if (ValidatorCallbacks->validate_cb == NULL)
798 0 : ereport(ERROR,
799 : errmsg("%s module \"%s\" must provide a %s callback",
800 : "OAuth validator", libname, "validate_cb"));
801 :
802 : /* Allocate memory for validator library private state data */
803 0 : validator_module_state = palloc0_object(ValidatorModuleState);
804 0 : validator_module_state->sversion = PG_VERSION_NUM;
805 :
806 0 : if (ValidatorCallbacks->startup_cb != NULL)
807 0 : ValidatorCallbacks->startup_cb(validator_module_state);
808 :
809 : /* Shut down the library before cleaning up its state. */
810 0 : mcb = palloc0_object(MemoryContextCallback);
811 0 : mcb->func = shutdown_validator_library;
812 :
813 0 : MemoryContextRegisterResetCallback(CurrentMemoryContext, mcb);
814 0 : }
815 :
816 : /*
817 : * Call the validator module's shutdown callback, if one is provided. This is
818 : * invoked during memory context reset.
819 : */
820 : static void
821 0 : shutdown_validator_library(void *arg)
822 : {
823 0 : if (ValidatorCallbacks->shutdown_cb != NULL)
824 0 : ValidatorCallbacks->shutdown_cb(validator_module_state);
825 0 : }
826 :
827 : /*
828 : * Ensure an OAuth validator named in the HBA is permitted by the configuration.
829 : *
830 : * If the validator is currently unset and exactly one library is declared in
831 : * oauth_validator_libraries, then that library will be used as the validator.
832 : * Otherwise the name must be present in the list of oauth_validator_libraries.
833 : */
834 : bool
835 0 : check_oauth_validator(HbaLine *hbaline, int elevel, char **err_msg)
836 : {
837 0 : int line_num = hbaline->linenumber;
838 0 : const char *file_name = hbaline->sourcefile;
839 : char *rawstring;
840 0 : List *elemlist = NIL;
841 :
842 0 : *err_msg = NULL;
843 :
844 0 : if (oauth_validator_libraries_string[0] == '\0')
845 : {
846 0 : ereport(elevel,
847 : errcode(ERRCODE_CONFIG_FILE_ERROR),
848 : errmsg("oauth_validator_libraries must be set for authentication method %s",
849 : "oauth"),
850 : errcontext("line %d of configuration file \"%s\"",
851 : line_num, file_name));
852 0 : *err_msg = psprintf("oauth_validator_libraries must be set for authentication method %s",
853 : "oauth");
854 0 : return false;
855 : }
856 :
857 : /* SplitDirectoriesString needs a modifiable copy */
858 0 : rawstring = pstrdup(oauth_validator_libraries_string);
859 :
860 0 : if (!SplitDirectoriesString(rawstring, ',', &elemlist))
861 : {
862 : /* syntax error in list */
863 0 : ereport(elevel,
864 : errcode(ERRCODE_CONFIG_FILE_ERROR),
865 : errmsg("invalid list syntax in parameter \"%s\"",
866 : "oauth_validator_libraries"));
867 0 : *err_msg = psprintf("invalid list syntax in parameter \"%s\"",
868 : "oauth_validator_libraries");
869 0 : goto done;
870 : }
871 :
872 0 : if (!hbaline->oauth_validator)
873 : {
874 0 : if (elemlist->length == 1)
875 : {
876 0 : hbaline->oauth_validator = pstrdup(linitial(elemlist));
877 0 : goto done;
878 : }
879 :
880 0 : ereport(elevel,
881 : errcode(ERRCODE_CONFIG_FILE_ERROR),
882 : errmsg("authentication method \"oauth\" requires argument \"validator\" to be set when oauth_validator_libraries contains multiple options"),
883 : errcontext("line %d of configuration file \"%s\"",
884 : line_num, file_name));
885 0 : *err_msg = "authentication method \"oauth\" requires argument \"validator\" to be set when oauth_validator_libraries contains multiple options";
886 0 : goto done;
887 : }
888 :
889 0 : foreach_ptr(char, allowed, elemlist)
890 : {
891 0 : if (strcmp(allowed, hbaline->oauth_validator) == 0)
892 0 : goto done;
893 : }
894 :
895 0 : ereport(elevel,
896 : errcode(ERRCODE_INVALID_PARAMETER_VALUE),
897 : errmsg("validator \"%s\" is not permitted by %s",
898 : hbaline->oauth_validator, "oauth_validator_libraries"),
899 : errcontext("line %d of configuration file \"%s\"",
900 : line_num, file_name));
901 0 : *err_msg = psprintf("validator \"%s\" is not permitted by %s",
902 : hbaline->oauth_validator, "oauth_validator_libraries");
903 :
904 0 : done:
905 0 : list_free_deep(elemlist);
906 0 : pfree(rawstring);
907 :
908 0 : return (*err_msg == NULL);
909 : }
|