Line data Source code
1 : /*-------------------------------------------------------------------------
2 : *
3 : * oauth_hook_client.c
4 : * Test driver for t/002_client.pl, which verifies OAuth hook
5 : * functionality in libpq.
6 : *
7 : * Portions Copyright (c) 1996-2026, PostgreSQL Global Development Group
8 : * Portions Copyright (c) 1994, Regents of the University of California
9 : *
10 : *
11 : * IDENTIFICATION
12 : * src/test/modules/oauth_validator/oauth_hook_client.c
13 : *
14 : *-------------------------------------------------------------------------
15 : */
16 :
17 : #include "postgres_fe.h"
18 :
19 : #include <sys/socket.h>
20 :
21 : #include "getopt_long.h"
22 : #include "libpq-fe.h"
23 :
24 : static int handle_auth_data(PGauthData type, PGconn *conn, void *data);
25 : static PostgresPollingStatusType async_cb(PGconn *conn,
26 : PGoauthBearerRequest *req,
27 : pgsocket *altsock);
28 : static PostgresPollingStatusType misbehave_cb(PGconn *conn,
29 : PGoauthBearerRequest *req,
30 : pgsocket *altsock);
31 :
/*
 * Prints the command-line synopsis and the list of recognized switches to
 * stdout. Only argv[0] is read, to name the program.
 */
static void
usage(char *argv[])
{
	/* One entry per help line; none of these contain printf conversions. */
	static const char *const flag_lines[] = {
		" -h, --help show this message\n",
		" -v VERSION select the hook API version (default 2)\n",
		" --expected-scope SCOPE fail if received scopes do not match SCOPE\n",
		" --expected-uri URI fail if received configuration link does not match URI\n",
		" --expected-issuer ISS fail if received issuer does not match ISS (v2 only)\n",
		" --misbehave=MODE have the hook fail required postconditions\n"
		" (MODEs: no-hook, fail-async, no-token, no-socket)\n",
		" --no-hook don't install OAuth hooks\n",
		" --hang-forever don't ever return a token (combine with connect_timeout)\n",
		" --token TOKEN use the provided TOKEN value\n",
		" --error ERRMSG fail instead, with the given ERRMSG (v2 only)\n",
		" --stress-async busy-loop on PQconnectPoll rather than polling\n",
	};
	const size_t nlines = sizeof(flag_lines) / sizeof(flag_lines[0]);

	printf("usage: %s [flags] CONNINFO\n\n", argv[0]);
	printf("recognized flags:\n");

	for (size_t idx = 0; idx < nlines; idx++)
		fputs(flag_lines[idx], stdout);
}
51 :
52 : /* --options */
53 : static bool no_hook = false;
54 : static bool hang_forever = false;
55 : static bool stress_async = false;
56 : static const char *expected_uri = NULL;
57 : static const char *expected_issuer = NULL;
58 : static const char *expected_scope = NULL;
59 : static const char *misbehave_mode = NULL;
60 : static char *token = NULL;
61 : static char *errmsg = NULL;
62 : static int hook_version = PQAUTHDATA_OAUTH_BEARER_TOKEN_V2;
63 :
64 : int
65 0 : main(int argc, char *argv[])
66 : {
67 : static const struct option long_options[] = {
68 : {"help", no_argument, NULL, 'h'},
69 :
70 : {"expected-scope", required_argument, NULL, 1000},
71 : {"expected-uri", required_argument, NULL, 1001},
72 : {"no-hook", no_argument, NULL, 1002},
73 : {"token", required_argument, NULL, 1003},
74 : {"hang-forever", no_argument, NULL, 1004},
75 : {"misbehave", required_argument, NULL, 1005},
76 : {"stress-async", no_argument, NULL, 1006},
77 : {"expected-issuer", required_argument, NULL, 1007},
78 : {"error", required_argument, NULL, 1008},
79 : {0}
80 : };
81 :
82 : const char *conninfo;
83 : PGconn *conn;
84 : int c;
85 :
86 0 : while ((c = getopt_long(argc, argv, "hv:", long_options, NULL)) != -1)
87 : {
88 0 : switch (c)
89 : {
90 0 : case 'h':
91 0 : usage(argv);
92 0 : return 0;
93 :
94 0 : case 'v':
95 0 : if (strcmp(optarg, "1") == 0)
96 0 : hook_version = PQAUTHDATA_OAUTH_BEARER_TOKEN;
97 0 : else if (strcmp(optarg, "2") == 0)
98 0 : hook_version = PQAUTHDATA_OAUTH_BEARER_TOKEN_V2;
99 : else
100 : {
101 0 : usage(argv);
102 0 : return 1;
103 : }
104 0 : break;
105 :
106 0 : case 1000: /* --expected-scope */
107 0 : expected_scope = optarg;
108 0 : break;
109 :
110 0 : case 1001: /* --expected-uri */
111 0 : expected_uri = optarg;
112 0 : break;
113 :
114 0 : case 1002: /* --no-hook */
115 0 : no_hook = true;
116 0 : break;
117 :
118 0 : case 1003: /* --token */
119 0 : token = optarg;
120 0 : break;
121 :
122 0 : case 1004: /* --hang-forever */
123 0 : hang_forever = true;
124 0 : break;
125 :
126 0 : case 1005: /* --misbehave */
127 0 : misbehave_mode = optarg;
128 0 : break;
129 :
130 0 : case 1006: /* --stress-async */
131 0 : stress_async = true;
132 0 : break;
133 :
134 0 : case 1007: /* --expected-issuer */
135 0 : expected_issuer = optarg;
136 0 : break;
137 :
138 0 : case 1008: /* --error */
139 0 : errmsg = optarg;
140 0 : break;
141 :
142 0 : default:
143 0 : usage(argv);
144 0 : return 1;
145 : }
146 : }
147 :
148 0 : if (argc != optind + 1)
149 : {
150 0 : usage(argv);
151 0 : return 1;
152 : }
153 :
154 0 : conninfo = argv[optind];
155 :
156 : /* Set up our OAuth hooks. */
157 0 : PQsetAuthDataHook(handle_auth_data);
158 :
159 : /* Connect. (All the actual work is in the hook.) */
160 0 : if (stress_async)
161 : {
162 : /*
163 : * Perform an asynchronous connection, busy-looping on PQconnectPoll()
164 : * without actually waiting on socket events. This stresses code paths
165 : * that rely on asynchronous work to be done before continuing with
166 : * the next step in the flow.
167 : */
168 : PostgresPollingStatusType res;
169 :
170 0 : conn = PQconnectStart(conninfo);
171 :
172 : do
173 : {
174 0 : res = PQconnectPoll(conn);
175 0 : } while (res != PGRES_POLLING_FAILED && res != PGRES_POLLING_OK);
176 : }
177 : else
178 : {
179 : /* Perform a standard synchronous connection. */
180 0 : conn = PQconnectdb(conninfo);
181 : }
182 :
183 0 : if (PQstatus(conn) != CONNECTION_OK)
184 : {
185 0 : fprintf(stderr, "connection to database failed: %s\n",
186 : PQerrorMessage(conn));
187 0 : PQfinish(conn);
188 0 : return 1;
189 : }
190 :
191 0 : printf("connection succeeded\n");
192 0 : PQfinish(conn);
193 0 : return 0;
194 : }
195 :
196 : /*
197 : * PQauthDataHook implementation. Replaces the default client flow by handling
198 : * PQAUTHDATA_OAUTH_BEARER_TOKEN[_V2].
199 : */
200 : static int
201 0 : handle_auth_data(PGauthData type, PGconn *conn, void *data)
202 : {
203 : PGoauthBearerRequest *req;
204 0 : PGoauthBearerRequestV2 *req2 = NULL;
205 :
206 : Assert(hook_version == PQAUTHDATA_OAUTH_BEARER_TOKEN ||
207 : hook_version == PQAUTHDATA_OAUTH_BEARER_TOKEN_V2);
208 :
209 0 : if (no_hook || type != hook_version)
210 0 : return 0;
211 :
212 0 : req = data;
213 0 : if (type == PQAUTHDATA_OAUTH_BEARER_TOKEN_V2)
214 0 : req2 = data;
215 :
216 0 : if (hang_forever)
217 : {
218 : /* Start asynchronous processing. */
219 0 : req->async = async_cb;
220 0 : return 1;
221 : }
222 :
223 0 : if (misbehave_mode)
224 : {
225 0 : if (strcmp(misbehave_mode, "no-hook") != 0)
226 0 : req->async = misbehave_cb;
227 0 : return 1;
228 : }
229 :
230 0 : if (expected_uri)
231 : {
232 0 : if (!req->openid_configuration)
233 : {
234 0 : fprintf(stderr, "expected URI \"%s\", got NULL\n", expected_uri);
235 0 : return -1;
236 : }
237 :
238 0 : if (strcmp(expected_uri, req->openid_configuration) != 0)
239 : {
240 0 : fprintf(stderr, "expected URI \"%s\", got \"%s\"\n", expected_uri, req->openid_configuration);
241 0 : return -1;
242 : }
243 : }
244 :
245 0 : if (expected_scope)
246 : {
247 0 : if (!req->scope)
248 : {
249 0 : fprintf(stderr, "expected scope \"%s\", got NULL\n", expected_scope);
250 0 : return -1;
251 : }
252 :
253 0 : if (strcmp(expected_scope, req->scope) != 0)
254 : {
255 0 : fprintf(stderr, "expected scope \"%s\", got \"%s\"\n", expected_scope, req->scope);
256 0 : return -1;
257 : }
258 : }
259 :
260 0 : if (expected_issuer)
261 : {
262 0 : if (!req2)
263 : {
264 0 : fprintf(stderr, "--expected-issuer cannot be combined with -v1\n");
265 0 : return -1;
266 : }
267 :
268 0 : if (!req2->issuer)
269 : {
270 0 : fprintf(stderr, "expected issuer \"%s\", got NULL\n", expected_issuer);
271 0 : return -1;
272 : }
273 :
274 0 : if (strcmp(expected_issuer, req2->issuer) != 0)
275 : {
276 0 : fprintf(stderr, "expected issuer \"%s\", got \"%s\"\n", expected_issuer, req2->issuer);
277 0 : return -1;
278 : }
279 : }
280 :
281 0 : if (errmsg)
282 : {
283 0 : if (token)
284 : {
285 0 : fprintf(stderr, "--error cannot be combined with --token\n");
286 0 : return -1;
287 : }
288 0 : else if (!req2)
289 : {
290 0 : fprintf(stderr, "--error cannot be combined with -v1\n");
291 0 : return -1;
292 : }
293 :
294 0 : req2->error = errmsg;
295 0 : return -1;
296 : }
297 :
298 0 : req->token = token;
299 0 : return 1;
300 : }
301 :
302 : static PostgresPollingStatusType
303 0 : async_cb(PGconn *conn, PGoauthBearerRequest *req, pgsocket *altsock)
304 : {
305 0 : if (hang_forever)
306 : {
307 : /*
308 : * This code tests that nothing is interfering with libpq's handling
309 : * of connect_timeout.
310 : */
311 : static pgsocket sock = PGINVALID_SOCKET;
312 :
313 0 : if (sock == PGINVALID_SOCKET)
314 : {
315 : /* First call. Create an unbound socket to wait on. */
316 : #ifdef WIN32
317 : WSADATA wsaData;
318 : int err;
319 :
320 : err = WSAStartup(MAKEWORD(2, 2), &wsaData);
321 : if (err)
322 : {
323 : perror("WSAStartup failed");
324 : return PGRES_POLLING_FAILED;
325 : }
326 : #endif
327 0 : sock = socket(AF_INET, SOCK_DGRAM, 0);
328 0 : if (sock == PGINVALID_SOCKET)
329 : {
330 0 : perror("failed to create datagram socket");
331 0 : return PGRES_POLLING_FAILED;
332 : }
333 : }
334 :
335 : /* Make libpq wait on the (unreadable) socket. */
336 0 : *altsock = sock;
337 0 : return PGRES_POLLING_READING;
338 : }
339 :
340 0 : req->token = token;
341 0 : return PGRES_POLLING_OK;
342 : }
343 :
344 : static PostgresPollingStatusType
345 0 : misbehave_cb(PGconn *conn, PGoauthBearerRequest *req, pgsocket *altsock)
346 : {
347 0 : if (strcmp(misbehave_mode, "fail-async") == 0)
348 : {
349 : /* Just fail "normally". */
350 0 : if (errmsg)
351 : {
352 : PGoauthBearerRequestV2 *req2;
353 :
354 0 : if (hook_version == PQAUTHDATA_OAUTH_BEARER_TOKEN)
355 : {
356 0 : fprintf(stderr, "--error cannot be combined with -v1\n");
357 0 : exit(1);
358 : }
359 :
360 0 : req2 = (PGoauthBearerRequestV2 *) req;
361 0 : req2->error = errmsg;
362 : }
363 :
364 0 : return PGRES_POLLING_FAILED;
365 : }
366 0 : else if (strcmp(misbehave_mode, "no-token") == 0)
367 : {
368 : /* Callbacks must assign req->token before returning OK. */
369 0 : return PGRES_POLLING_OK;
370 : }
371 0 : else if (strcmp(misbehave_mode, "no-socket") == 0)
372 : {
373 : /* Callbacks must assign *altsock before asking for polling. */
374 0 : return PGRES_POLLING_READING;
375 : }
376 : else
377 : {
378 0 : fprintf(stderr, "unrecognized --misbehave mode: %s\n", misbehave_mode);
379 0 : exit(1);
380 : }
381 : }
|