LCOV - code coverage report
Current view: top level - src/test/modules/oauth_validator - oauth_hook_client.c (source / functions) Coverage Total Hit
Test: PostgreSQL 19devel Lines: 0.0 % 151 0
Test Date: 2026-03-23 01:16:05 Functions: 0.0 % 5 0
Legend: Lines:     hit not hit

            Line data    Source code
       1              : /*-------------------------------------------------------------------------
       2              :  *
       3              :  * oauth_hook_client.c
       4              :  *      Test driver for t/002_client.pl, which verifies OAuth hook
       5              :  *      functionality in libpq.
       6              :  *
       7              :  * Portions Copyright (c) 1996-2026, PostgreSQL Global Development Group
       8              :  * Portions Copyright (c) 1994, Regents of the University of California
       9              :  *
      10              :  *
      11              :  * IDENTIFICATION
      12              :  *      src/test/modules/oauth_validator/oauth_hook_client.c
      13              :  *
      14              :  *-------------------------------------------------------------------------
      15              :  */
      16              : 
      17              : #include "postgres_fe.h"
      18              : 
      19              : #include <sys/socket.h>
      20              : 
      21              : #include "getopt_long.h"
      22              : #include "libpq-fe.h"
      23              : 
      24              : static int  handle_auth_data(PGauthData type, PGconn *conn, void *data);
      25              : static PostgresPollingStatusType async_cb(PGconn *conn,
      26              :                                           PGoauthBearerRequest *req,
      27              :                                           pgsocket *altsock);
      28              : static PostgresPollingStatusType misbehave_cb(PGconn *conn,
      29              :                                               PGoauthBearerRequest *req,
      30              :                                               pgsocket *altsock);
      31              : 
      32              : static void
      33            0 : usage(char *argv[])
      34              : {
      35            0 :     printf("usage: %s [flags] CONNINFO\n\n", argv[0]);
      36              : 
      37            0 :     printf("recognized flags:\n");
      38            0 :     printf("  -h, --help              show this message\n");
      39            0 :     printf("  -v VERSION              select the hook API version (default 2)\n");
      40            0 :     printf("  --expected-scope SCOPE  fail if received scopes do not match SCOPE\n");
      41            0 :     printf("  --expected-uri URI      fail if received configuration link does not match URI\n");
      42            0 :     printf("  --expected-issuer ISS   fail if received issuer does not match ISS (v2 only)\n");
      43            0 :     printf("  --misbehave=MODE        have the hook fail required postconditions\n"
      44              :            "                          (MODEs: no-hook, fail-async, no-token, no-socket)\n");
      45            0 :     printf("  --no-hook               don't install OAuth hooks\n");
      46            0 :     printf("  --hang-forever          don't ever return a token (combine with connect_timeout)\n");
      47            0 :     printf("  --token TOKEN           use the provided TOKEN value\n");
      48            0 :     printf("  --error ERRMSG          fail instead, with the given ERRMSG (v2 only)\n");
      49            0 :     printf("  --stress-async          busy-loop on PQconnectPoll rather than polling\n");
      50            0 : }
      51              : 
      52              : /* --options */
      53              : static bool no_hook = false;
      54              : static bool hang_forever = false;
      55              : static bool stress_async = false;
      56              : static const char *expected_uri = NULL;
      57              : static const char *expected_issuer = NULL;
      58              : static const char *expected_scope = NULL;
      59              : static const char *misbehave_mode = NULL;
      60              : static char *token = NULL;
      61              : static char *errmsg = NULL;
      62              : static int  hook_version = PQAUTHDATA_OAUTH_BEARER_TOKEN_V2;
      63              : 
      64              : int
      65            0 : main(int argc, char *argv[])
      66              : {
      67              :     static const struct option long_options[] = {
      68              :         {"help", no_argument, NULL, 'h'},
      69              : 
      70              :         {"expected-scope", required_argument, NULL, 1000},
      71              :         {"expected-uri", required_argument, NULL, 1001},
      72              :         {"no-hook", no_argument, NULL, 1002},
      73              :         {"token", required_argument, NULL, 1003},
      74              :         {"hang-forever", no_argument, NULL, 1004},
      75              :         {"misbehave", required_argument, NULL, 1005},
      76              :         {"stress-async", no_argument, NULL, 1006},
      77              :         {"expected-issuer", required_argument, NULL, 1007},
      78              :         {"error", required_argument, NULL, 1008},
      79              :         {0}
      80              :     };
      81              : 
      82              :     const char *conninfo;
      83              :     PGconn     *conn;
      84              :     int         c;
      85              : 
      86            0 :     while ((c = getopt_long(argc, argv, "hv:", long_options, NULL)) != -1)
      87              :     {
      88            0 :         switch (c)
      89              :         {
      90            0 :             case 'h':
      91            0 :                 usage(argv);
      92            0 :                 return 0;
      93              : 
      94            0 :             case 'v':
      95            0 :                 if (strcmp(optarg, "1") == 0)
      96            0 :                     hook_version = PQAUTHDATA_OAUTH_BEARER_TOKEN;
      97            0 :                 else if (strcmp(optarg, "2") == 0)
      98            0 :                     hook_version = PQAUTHDATA_OAUTH_BEARER_TOKEN_V2;
      99              :                 else
     100              :                 {
     101            0 :                     usage(argv);
     102            0 :                     return 1;
     103              :                 }
     104            0 :                 break;
     105              : 
     106            0 :             case 1000:          /* --expected-scope */
     107            0 :                 expected_scope = optarg;
     108            0 :                 break;
     109              : 
     110            0 :             case 1001:          /* --expected-uri */
     111            0 :                 expected_uri = optarg;
     112            0 :                 break;
     113              : 
     114            0 :             case 1002:          /* --no-hook */
     115            0 :                 no_hook = true;
     116            0 :                 break;
     117              : 
     118            0 :             case 1003:          /* --token */
     119            0 :                 token = optarg;
     120            0 :                 break;
     121              : 
     122            0 :             case 1004:          /* --hang-forever */
     123            0 :                 hang_forever = true;
     124            0 :                 break;
     125              : 
     126            0 :             case 1005:          /* --misbehave */
     127            0 :                 misbehave_mode = optarg;
     128            0 :                 break;
     129              : 
     130            0 :             case 1006:          /* --stress-async */
     131            0 :                 stress_async = true;
     132            0 :                 break;
     133              : 
     134            0 :             case 1007:          /* --expected-issuer */
     135            0 :                 expected_issuer = optarg;
     136            0 :                 break;
     137              : 
     138            0 :             case 1008:          /* --error */
     139            0 :                 errmsg = optarg;
     140            0 :                 break;
     141              : 
     142            0 :             default:
     143            0 :                 usage(argv);
     144            0 :                 return 1;
     145              :         }
     146              :     }
     147              : 
     148            0 :     if (argc != optind + 1)
     149              :     {
     150            0 :         usage(argv);
     151            0 :         return 1;
     152              :     }
     153              : 
     154            0 :     conninfo = argv[optind];
     155              : 
     156              :     /* Set up our OAuth hooks. */
     157            0 :     PQsetAuthDataHook(handle_auth_data);
     158              : 
     159              :     /* Connect. (All the actual work is in the hook.) */
     160            0 :     if (stress_async)
     161              :     {
     162              :         /*
     163              :          * Perform an asynchronous connection, busy-looping on PQconnectPoll()
     164              :          * without actually waiting on socket events. This stresses code paths
     165              :          * that rely on asynchronous work to be done before continuing with
     166              :          * the next step in the flow.
     167              :          */
     168              :         PostgresPollingStatusType res;
     169              : 
     170            0 :         conn = PQconnectStart(conninfo);
     171              : 
     172              :         do
     173              :         {
     174            0 :             res = PQconnectPoll(conn);
     175            0 :         } while (res != PGRES_POLLING_FAILED && res != PGRES_POLLING_OK);
     176              :     }
     177              :     else
     178              :     {
     179              :         /* Perform a standard synchronous connection. */
     180            0 :         conn = PQconnectdb(conninfo);
     181              :     }
     182              : 
     183            0 :     if (PQstatus(conn) != CONNECTION_OK)
     184              :     {
     185            0 :         fprintf(stderr, "connection to database failed: %s\n",
     186              :                 PQerrorMessage(conn));
     187            0 :         PQfinish(conn);
     188            0 :         return 1;
     189              :     }
     190              : 
     191            0 :     printf("connection succeeded\n");
     192            0 :     PQfinish(conn);
     193            0 :     return 0;
     194              : }
     195              : 
     196              : /*
     197              :  * PQauthDataHook implementation. Replaces the default client flow by handling
     198              :  * PQAUTHDATA_OAUTH_BEARER_TOKEN[_V2].
     199              :  */
     200              : static int
     201            0 : handle_auth_data(PGauthData type, PGconn *conn, void *data)
     202              : {
     203              :     PGoauthBearerRequest *req;
     204            0 :     PGoauthBearerRequestV2 *req2 = NULL;
     205              : 
     206              :     Assert(hook_version == PQAUTHDATA_OAUTH_BEARER_TOKEN ||
     207              :            hook_version == PQAUTHDATA_OAUTH_BEARER_TOKEN_V2);
     208              : 
     209            0 :     if (no_hook || type != hook_version)
     210            0 :         return 0;
     211              : 
     212            0 :     req = data;
     213            0 :     if (type == PQAUTHDATA_OAUTH_BEARER_TOKEN_V2)
     214            0 :         req2 = data;
     215              : 
     216            0 :     if (hang_forever)
     217              :     {
     218              :         /* Start asynchronous processing. */
     219            0 :         req->async = async_cb;
     220            0 :         return 1;
     221              :     }
     222              : 
     223            0 :     if (misbehave_mode)
     224              :     {
     225            0 :         if (strcmp(misbehave_mode, "no-hook") != 0)
     226            0 :             req->async = misbehave_cb;
     227            0 :         return 1;
     228              :     }
     229              : 
     230            0 :     if (expected_uri)
     231              :     {
     232            0 :         if (!req->openid_configuration)
     233              :         {
     234            0 :             fprintf(stderr, "expected URI \"%s\", got NULL\n", expected_uri);
     235            0 :             return -1;
     236              :         }
     237              : 
     238            0 :         if (strcmp(expected_uri, req->openid_configuration) != 0)
     239              :         {
     240            0 :             fprintf(stderr, "expected URI \"%s\", got \"%s\"\n", expected_uri, req->openid_configuration);
     241            0 :             return -1;
     242              :         }
     243              :     }
     244              : 
     245            0 :     if (expected_scope)
     246              :     {
     247            0 :         if (!req->scope)
     248              :         {
     249            0 :             fprintf(stderr, "expected scope \"%s\", got NULL\n", expected_scope);
     250            0 :             return -1;
     251              :         }
     252              : 
     253            0 :         if (strcmp(expected_scope, req->scope) != 0)
     254              :         {
     255            0 :             fprintf(stderr, "expected scope \"%s\", got \"%s\"\n", expected_scope, req->scope);
     256            0 :             return -1;
     257              :         }
     258              :     }
     259              : 
     260            0 :     if (expected_issuer)
     261              :     {
     262            0 :         if (!req2)
     263              :         {
     264            0 :             fprintf(stderr, "--expected-issuer cannot be combined with -v1\n");
     265            0 :             return -1;
     266              :         }
     267              : 
     268            0 :         if (!req2->issuer)
     269              :         {
     270            0 :             fprintf(stderr, "expected issuer \"%s\", got NULL\n", expected_issuer);
     271            0 :             return -1;
     272              :         }
     273              : 
     274            0 :         if (strcmp(expected_issuer, req2->issuer) != 0)
     275              :         {
     276            0 :             fprintf(stderr, "expected issuer \"%s\", got \"%s\"\n", expected_issuer, req2->issuer);
     277            0 :             return -1;
     278              :         }
     279              :     }
     280              : 
     281            0 :     if (errmsg)
     282              :     {
     283            0 :         if (token)
     284              :         {
     285            0 :             fprintf(stderr, "--error cannot be combined with --token\n");
     286            0 :             return -1;
     287              :         }
     288            0 :         else if (!req2)
     289              :         {
     290            0 :             fprintf(stderr, "--error cannot be combined with -v1\n");
     291            0 :             return -1;
     292              :         }
     293              : 
     294            0 :         req2->error = errmsg;
     295            0 :         return -1;
     296              :     }
     297              : 
     298            0 :     req->token = token;
     299            0 :     return 1;
     300              : }
     301              : 
     302              : static PostgresPollingStatusType
     303            0 : async_cb(PGconn *conn, PGoauthBearerRequest *req, pgsocket *altsock)
     304              : {
     305            0 :     if (hang_forever)
     306              :     {
     307              :         /*
     308              :          * This code tests that nothing is interfering with libpq's handling
     309              :          * of connect_timeout.
     310              :          */
     311              :         static pgsocket sock = PGINVALID_SOCKET;
     312              : 
     313            0 :         if (sock == PGINVALID_SOCKET)
     314              :         {
     315              :             /* First call. Create an unbound socket to wait on. */
     316              : #ifdef WIN32
     317              :             WSADATA     wsaData;
     318              :             int         err;
     319              : 
     320              :             err = WSAStartup(MAKEWORD(2, 2), &wsaData);
     321              :             if (err)
     322              :             {
     323              :                 perror("WSAStartup failed");
     324              :                 return PGRES_POLLING_FAILED;
     325              :             }
     326              : #endif
     327            0 :             sock = socket(AF_INET, SOCK_DGRAM, 0);
     328            0 :             if (sock == PGINVALID_SOCKET)
     329              :             {
     330            0 :                 perror("failed to create datagram socket");
     331            0 :                 return PGRES_POLLING_FAILED;
     332              :             }
     333              :         }
     334              : 
     335              :         /* Make libpq wait on the (unreadable) socket. */
     336            0 :         *altsock = sock;
     337            0 :         return PGRES_POLLING_READING;
     338              :     }
     339              : 
     340            0 :     req->token = token;
     341            0 :     return PGRES_POLLING_OK;
     342              : }
     343              : 
     344              : static PostgresPollingStatusType
     345            0 : misbehave_cb(PGconn *conn, PGoauthBearerRequest *req, pgsocket *altsock)
     346              : {
     347            0 :     if (strcmp(misbehave_mode, "fail-async") == 0)
     348              :     {
     349              :         /* Just fail "normally". */
     350            0 :         if (errmsg)
     351              :         {
     352              :             PGoauthBearerRequestV2 *req2;
     353              : 
     354            0 :             if (hook_version == PQAUTHDATA_OAUTH_BEARER_TOKEN)
     355              :             {
     356            0 :                 fprintf(stderr, "--error cannot be combined with -v1\n");
     357            0 :                 exit(1);
     358              :             }
     359              : 
     360            0 :             req2 = (PGoauthBearerRequestV2 *) req;
     361            0 :             req2->error = errmsg;
     362              :         }
     363              : 
     364            0 :         return PGRES_POLLING_FAILED;
     365              :     }
     366            0 :     else if (strcmp(misbehave_mode, "no-token") == 0)
     367              :     {
     368              :         /* Callbacks must assign req->token before returning OK. */
     369            0 :         return PGRES_POLLING_OK;
     370              :     }
     371            0 :     else if (strcmp(misbehave_mode, "no-socket") == 0)
     372              :     {
     373              :         /* Callbacks must assign *altsock before asking for polling. */
     374            0 :         return PGRES_POLLING_READING;
     375              :     }
     376              :     else
     377              :     {
     378            0 :         fprintf(stderr, "unrecognized --misbehave mode: %s\n", misbehave_mode);
     379            0 :         exit(1);
     380              :     }
     381              : }
        

Generated by: LCOV version 2.0-1