LCOV - code coverage report
Current view: top level - src/backend/statistics - mvdistinct.c (source / functions) Hit Total Coverage
Test: PostgreSQL 19devel Lines: 188 201 93.5 %
Date: 2026-01-26 19:17:09 Functions: 15 15 100.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : /*-------------------------------------------------------------------------
       2             :  *
       3             :  * mvdistinct.c
       4             :  *    POSTGRES multivariate ndistinct coefficients
       5             :  *
       6             :  * Estimating number of groups in a combination of columns (e.g. for GROUP BY)
       7             :  * is tricky, and the estimation error is often significant.
       8             :  *
       9             :  * The multivariate ndistinct coefficients address this by storing ndistinct
      10             :  * estimates for combinations of the user-specified columns.  So for example
      11             :  * given a statistics object on three columns (a,b,c), this module estimates
      12             :  * and stores n-distinct for (a,b), (a,c), (b,c) and (a,b,c).  The per-column
      13             :  * estimates are already available in pg_statistic.
      14             :  *
      15             :  *
      16             :  * Portions Copyright (c) 1996-2026, PostgreSQL Global Development Group
      17             :  * Portions Copyright (c) 1994, Regents of the University of California
      18             :  *
      19             :  * IDENTIFICATION
      20             :  *    src/backend/statistics/mvdistinct.c
      21             :  *
      22             :  *-------------------------------------------------------------------------
      23             :  */
      24             : #include "postgres.h"
      25             : 
      26             : #include <math.h>
      27             : 
      28             : #include "catalog/pg_statistic_ext.h"
      29             : #include "catalog/pg_statistic_ext_data.h"
      30             : #include "statistics/extended_stats_internal.h"
      31             : #include "utils/syscache.h"
      32             : #include "utils/typcache.h"
      33             : #include "varatt.h"
      34             : 
      35             : static double ndistinct_for_combination(double totalrows, StatsBuildData *data,
      36             :                                         int k, int *combination);
      37             : static double estimate_ndistinct(double totalrows, int numrows, int d, int f1);
      38             : static int  n_choose_k(int n, int k);
      39             : static int  num_combinations(int n);
      40             : 
      41             : /* size of the struct header fields (magic, type, nitems) */
      42             : #define SizeOfHeader        (3 * sizeof(uint32))
      43             : 
      44             : /* size of a serialized ndistinct item (coefficient, natts, atts) */
      45             : #define SizeOfItem(natts) \
      46             :     (sizeof(double) + sizeof(int) + (natts) * sizeof(AttrNumber))
      47             : 
      48             : /* minimal size of a ndistinct item (with two attributes) */
      49             : #define MinSizeOfItem   SizeOfItem(2)
      50             : 
      51             : /* minimal size of mvndistinct, when all items are minimal */
      52             : #define MinSizeOfItems(nitems)  \
      53             :     (SizeOfHeader + (nitems) * MinSizeOfItem)
      54             : 
      55             : /* Combination generator API */
      56             : 
      57             : /* internal state for generator of k-combinations of n elements */
      58             : typedef struct CombinationGenerator
      59             : {
      60             :     int         k;              /* size of the combination */
      61             :     int         n;              /* total number of elements */
      62             :     int         current;        /* index of the next combination to return */
      63             :     int         ncombinations;  /* number of combinations (size of array) */
      64             :     int        *combinations;   /* array of pre-built combinations */
      65             : } CombinationGenerator;
      66             : 
      67             : static CombinationGenerator *generator_init(int n, int k);
      68             : static void generator_free(CombinationGenerator *state);
      69             : static int *generator_next(CombinationGenerator *state);
      70             : static void generate_combinations(CombinationGenerator *state);
      71             : 
      72             : 
      73             : /*
      74             :  * statext_ndistinct_build
      75             :  *      Compute ndistinct coefficients for all combinations of attributes.
      76             :  *
      77             :  * This computes the ndistinct estimate using the same estimator used
      78             :  * in analyze.c and then computes the coefficient.
      79             :  *
      80             :  * To handle expressions easily, we treat them as system attributes with
      81             :  * negative attnums, and offset everything by number of expressions to
      82             :  * allow using Bitmapsets.
      83             :  */
      84             : MVNDistinct *
      85         222 : statext_ndistinct_build(double totalrows, StatsBuildData *data)
      86             : {
      87             :     MVNDistinct *result;
      88             :     int         k;
      89             :     int         itemcnt;
      90         222 :     int         numattrs = data->nattnums;
      91         222 :     int         numcombs = num_combinations(numattrs);
      92             : 
      93         222 :     result = palloc(offsetof(MVNDistinct, items) +
      94         222 :                     numcombs * sizeof(MVNDistinctItem));
      95         222 :     result->magic = STATS_NDISTINCT_MAGIC;
      96         222 :     result->type = STATS_NDISTINCT_TYPE_BASIC;
      97         222 :     result->nitems = numcombs;
      98             : 
      99         222 :     itemcnt = 0;
     100         552 :     for (k = 2; k <= numattrs; k++)
     101             :     {
     102             :         int        *combination;
     103             :         CombinationGenerator *generator;
     104             : 
     105             :         /* generate combinations of K out of N elements */
     106         330 :         generator = generator_init(numattrs, k);
     107             : 
     108         996 :         while ((combination = generator_next(generator)))
     109             :         {
     110         666 :             MVNDistinctItem *item = &result->items[itemcnt];
     111             :             int         j;
     112             : 
     113         666 :             item->attributes = palloc_array(AttrNumber, k);
     114         666 :             item->nattributes = k;
     115             : 
     116             :             /* translate the indexes to attnums */
     117        2226 :             for (j = 0; j < k; j++)
     118             :             {
     119        1560 :                 item->attributes[j] = data->attnums[combination[j]];
     120             : 
     121             :                 Assert(AttributeNumberIsValid(item->attributes[j]));
     122             :             }
     123             : 
     124         666 :             item->ndistinct =
     125         666 :                 ndistinct_for_combination(totalrows, data, k, combination);
     126             : 
     127         666 :             itemcnt++;
     128             :             Assert(itemcnt <= result->nitems);
     129             :         }
     130             : 
     131         330 :         generator_free(generator);
     132             :     }
     133             : 
     134             :     /* must consume exactly the whole output array */
     135             :     Assert(itemcnt == result->nitems);
     136             : 
     137         222 :     return result;
     138             : }
     139             : 
     140             : /*
     141             :  * statext_ndistinct_load
     142             :  *      Load the ndistinct value for the indicated pg_statistic_ext tuple
     143             :  */
     144             : MVNDistinct *
     145         426 : statext_ndistinct_load(Oid mvoid, bool inh)
     146             : {
     147             :     MVNDistinct *result;
     148             :     bool        isnull;
     149             :     Datum       ndist;
     150             :     HeapTuple   htup;
     151             : 
     152         426 :     htup = SearchSysCache2(STATEXTDATASTXOID,
     153             :                            ObjectIdGetDatum(mvoid), BoolGetDatum(inh));
     154         426 :     if (!HeapTupleIsValid(htup))
     155           0 :         elog(ERROR, "cache lookup failed for statistics object %u", mvoid);
     156             : 
     157         426 :     ndist = SysCacheGetAttr(STATEXTDATASTXOID, htup,
     158             :                             Anum_pg_statistic_ext_data_stxdndistinct, &isnull);
     159         426 :     if (isnull)
     160           0 :         elog(ERROR,
     161             :              "requested statistics kind \"%c\" is not yet built for statistics object %u",
     162             :              STATS_EXT_NDISTINCT, mvoid);
     163             : 
     164         426 :     result = statext_ndistinct_deserialize(DatumGetByteaPP(ndist));
     165             : 
     166         426 :     ReleaseSysCache(htup);
     167             : 
     168         426 :     return result;
     169             : }
     170             : 
     171             : /*
     172             :  * statext_ndistinct_serialize
     173             :  *      serialize ndistinct to the on-disk bytea format
     174             :  */
     175             : bytea *
     176         264 : statext_ndistinct_serialize(MVNDistinct *ndistinct)
     177             : {
     178             :     int         i;
     179             :     bytea      *output;
     180             :     char       *tmp;
     181             :     Size        len;
     182             : 
     183             :     Assert(ndistinct->magic == STATS_NDISTINCT_MAGIC);
     184             :     Assert(ndistinct->type == STATS_NDISTINCT_TYPE_BASIC);
     185             : 
     186             :     /*
     187             :      * Base size is size of scalar fields in the struct, plus one base struct
     188             :      * for each item, including number of items for each.
     189             :      */
     190         264 :     len = VARHDRSZ + SizeOfHeader;
     191             : 
     192             :     /* and also include space for the actual attribute numbers */
     193         990 :     for (i = 0; i < ndistinct->nitems; i++)
     194             :     {
     195             :         int         nmembers;
     196             : 
     197         726 :         nmembers = ndistinct->items[i].nattributes;
     198             :         Assert(nmembers >= 2);
     199             : 
     200         726 :         len += SizeOfItem(nmembers);
     201             :     }
     202             : 
     203         264 :     output = (bytea *) palloc(len);
     204         264 :     SET_VARSIZE(output, len);
     205             : 
     206         264 :     tmp = VARDATA(output);
     207             : 
     208             :     /* Store the base struct values (magic, type, nitems) */
     209         264 :     memcpy(tmp, &ndistinct->magic, sizeof(uint32));
     210         264 :     tmp += sizeof(uint32);
     211         264 :     memcpy(tmp, &ndistinct->type, sizeof(uint32));
     212         264 :     tmp += sizeof(uint32);
     213         264 :     memcpy(tmp, &ndistinct->nitems, sizeof(uint32));
     214         264 :     tmp += sizeof(uint32);
     215             : 
     216             :     /*
     217             :      * store number of attributes and attribute numbers for each entry
     218             :      */
     219         990 :     for (i = 0; i < ndistinct->nitems; i++)
     220             :     {
     221         726 :         MVNDistinctItem item = ndistinct->items[i];
     222         726 :         int         nmembers = item.nattributes;
     223             : 
     224         726 :         memcpy(tmp, &item.ndistinct, sizeof(double));
     225         726 :         tmp += sizeof(double);
     226         726 :         memcpy(tmp, &nmembers, sizeof(int));
     227         726 :         tmp += sizeof(int);
     228             : 
     229         726 :         memcpy(tmp, item.attributes, sizeof(AttrNumber) * nmembers);
     230         726 :         tmp += nmembers * sizeof(AttrNumber);
     231             : 
     232             :         /* protect against overflows */
     233             :         Assert(tmp <= ((char *) output + len));
     234             :     }
     235             : 
     236             :     /* check we used exactly the expected space */
     237             :     Assert(tmp == ((char *) output + len));
     238             : 
     239         264 :     return output;
     240             : }
     241             : 
     242             : /*
     243             :  * statext_ndistinct_deserialize
     244             :  *      Read an on-disk bytea format MVNDistinct to in-memory format
     245             :  */
     246             : MVNDistinct *
     247         492 : statext_ndistinct_deserialize(bytea *data)
     248             : {
     249             :     int         i;
     250             :     Size        minimum_size;
     251             :     MVNDistinct ndist;
     252             :     MVNDistinct *ndistinct;
     253             :     char       *tmp;
     254             : 
     255         492 :     if (data == NULL)
     256           0 :         return NULL;
     257             : 
     258             :     /* we expect at least the basic fields of MVNDistinct struct */
     259         492 :     if (VARSIZE_ANY_EXHDR(data) < SizeOfHeader)
     260           0 :         elog(ERROR, "invalid MVNDistinct size %zu (expected at least %zu)",
     261             :              VARSIZE_ANY_EXHDR(data), SizeOfHeader);
     262             : 
     263             :     /* initialize pointer to the data part (skip the varlena header) */
     264         492 :     tmp = VARDATA_ANY(data);
     265             : 
     266             :     /* read the header fields and perform basic sanity checks */
     267         492 :     memcpy(&ndist.magic, tmp, sizeof(uint32));
     268         492 :     tmp += sizeof(uint32);
     269         492 :     memcpy(&ndist.type, tmp, sizeof(uint32));
     270         492 :     tmp += sizeof(uint32);
     271         492 :     memcpy(&ndist.nitems, tmp, sizeof(uint32));
     272         492 :     tmp += sizeof(uint32);
     273             : 
     274         492 :     if (ndist.magic != STATS_NDISTINCT_MAGIC)
     275           0 :         elog(ERROR, "invalid ndistinct magic %08x (expected %08x)",
     276             :              ndist.magic, STATS_NDISTINCT_MAGIC);
     277         492 :     if (ndist.type != STATS_NDISTINCT_TYPE_BASIC)
     278           0 :         elog(ERROR, "invalid ndistinct type %d (expected %d)",
     279             :              ndist.type, STATS_NDISTINCT_TYPE_BASIC);
     280         492 :     if (ndist.nitems == 0)
     281           0 :         elog(ERROR, "invalid zero-length item array in MVNDistinct");
     282             : 
     283             :     /* what minimum bytea size do we expect for those parameters */
     284         492 :     minimum_size = MinSizeOfItems(ndist.nitems);
     285         492 :     if (VARSIZE_ANY_EXHDR(data) < minimum_size)
     286           0 :         elog(ERROR, "invalid MVNDistinct size %zu (expected at least %zu)",
     287             :              VARSIZE_ANY_EXHDR(data), minimum_size);
     288             : 
     289             :     /*
     290             :      * Allocate space for the ndistinct items (no space for each item's
     291             :      * attnums: those are palloc'd as separate arrays in the loop below)
     292             :      */
     293         492 :     ndistinct = palloc0(MAXALIGN(offsetof(MVNDistinct, items)) +
     294         492 :                         (ndist.nitems * sizeof(MVNDistinctItem)));
     295         492 :     ndistinct->magic = ndist.magic;
     296         492 :     ndistinct->type = ndist.type;
     297         492 :     ndistinct->nitems = ndist.nitems;
     298             : 
     299        2514 :     for (i = 0; i < ndistinct->nitems; i++)
     300             :     {
     301        2022 :         MVNDistinctItem *item = &ndistinct->items[i];
     302             : 
     303             :         /* ndistinct value */
     304        2022 :         memcpy(&item->ndistinct, tmp, sizeof(double));
     305        2022 :         tmp += sizeof(double);
     306             : 
     307             :         /* number of attributes */
     308        2022 :         memcpy(&item->nattributes, tmp, sizeof(int));
     309        2022 :         tmp += sizeof(int);
     310             :         Assert((item->nattributes >= 2) && (item->nattributes <= STATS_MAX_DIMENSIONS));
     311             : 
     312             :         item->attributes
     313        2022 :             = (AttrNumber *) palloc(item->nattributes * sizeof(AttrNumber));
     314             : 
     315        2022 :         memcpy(item->attributes, tmp, sizeof(AttrNumber) * item->nattributes);
     316        2022 :         tmp += sizeof(AttrNumber) * item->nattributes;
     317             : 
     318             :         /* still within the bytea */
     319             :         Assert(tmp <= ((char *) data + VARSIZE_ANY(data)));
     320             :     }
     321             : 
     322             :     /* we should have consumed the whole bytea exactly */
     323             :     Assert(tmp == ((char *) data + VARSIZE_ANY(data)));
     324             : 
     325         492 :     return ndistinct;
     326             : }
     327             : 
     328             : /*
     329             :  * Free allocations of a MVNDistinct.
     330             :  */
     331             : void
     332          18 : statext_ndistinct_free(MVNDistinct *ndistinct)
     333             : {
     334          36 :     for (int i = 0; i < ndistinct->nitems; i++)
     335          18 :         pfree(ndistinct->items[i].attributes);
     336          18 :     pfree(ndistinct);
     337          18 : }
     338             : 
     339             : /*
     340             :  * Validate a set of MVNDistincts against the extended statistics object
     341             :  * definition.
     342             :  *
     343             :  * Every MVNDistinctItem must be checked to ensure that the attnums in the
     344             :  * attributes list correspond to attnums/expressions defined by the extended
     345             :  * statistics object.
     346             :  *
     347             :  * Positive attnums are attributes which must be found in the stxkeys,
     348             :  * while negative attnums correspond to an expression number; no attribute
     349             :  * number can be below (0 - numexprs).
     350             :  */
     351             : bool
     352          18 : statext_ndistinct_validate(const MVNDistinct *ndistinct,
     353             :                            const int2vector *stxkeys,
     354             :                            int numexprs, int elevel)
     355             : {
     356          18 :     int         attnum_expr_lowbound = 0 - numexprs;
     357             : 
     358             :     /* Scan through each MVNDistinct entry */
     359          30 :     for (int i = 0; i < ndistinct->nitems; i++)
     360             :     {
     361          18 :         MVNDistinctItem item = ndistinct->items[i];
     362             : 
     363             :         /*
     364             :          * Cross-check each attribute in a MVNDistinct entry with the extended
     365             :          * stats object definition.
     366             :          */
     367          42 :         for (int j = 0; j < item.nattributes; j++)
     368             :         {
     369          30 :             AttrNumber  attnum = item.attributes[j];
     370          30 :             bool        ok = false;
     371             : 
     372          30 :             if (attnum > 0)
     373             :             {
     374             :                 /* attribute number in stxkeys */
     375          54 :                 for (int k = 0; k < stxkeys->dim1; k++)
     376             :                 {
     377          48 :                     if (attnum == stxkeys->values[k])
     378             :                     {
     379          24 :                         ok = true;
     380          24 :                         break;
     381             :                     }
     382             :                 }
     383             :             }
     384           0 :             else if ((attnum < 0) && (attnum >= attnum_expr_lowbound))
     385             :             {
     386             :                 /* attribute number for an expression */
     387           0 :                 ok = true;
     388             :             }
     389             : 
     390          30 :             if (!ok)
     391             :             {
     392           6 :                 ereport(elevel,
     393             :                         (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
     394             :                          errmsg("could not validate \"%s\" object: invalid attribute number %d found",
     395             :                                 "pg_ndistinct", attnum)));
     396           6 :                 return false;
     397             :             }
     398             :         }
     399             :     }
     400             : 
     401          12 :     return true;
     402             : }
     403             : 
     404             : /*
     405             :  * ndistinct_for_combination
     406             :  *      Estimates number of distinct values in a combination of columns.
     407             :  *
     408             :  * This uses the same ndistinct estimator as compute_scalar_stats() in
     409             :  * ANALYZE, i.e.,
     410             :  *      n*d / (n - f1 + f1*n/N)
     411             :  *
     412             :  * except that instead of values in a single column we are dealing with
     413             :  * combination of multiple columns.
     414             :  */
     415             : static double
     416         666 : ndistinct_for_combination(double totalrows, StatsBuildData *data,
     417             :                           int k, int *combination)
     418             : {
     419             :     int         i,
     420             :                 j;
     421             :     int         f1,
     422             :                 cnt,
     423             :                 d;
     424             :     bool       *isnull;
     425             :     Datum      *values;
     426             :     SortItem   *items;
     427             :     MultiSortSupport mss;
     428         666 :     int         numrows = data->numrows;
     429             : 
     430         666 :     mss = multi_sort_init(k);
     431             : 
     432             :     /*
     433             :      * In order to determine the number of distinct elements, create separate
     434             :      * values[]/isnull[] arrays with all the data we have, then sort them
     435             :      * using the specified column combination as dimensions.  We could try to
     436             :      * sort in place, but it'd probably be more complex and bug-prone.
     437             :      */
     438         666 :     items = palloc_array(SortItem, numrows);
     439         666 :     values = palloc0_array(Datum, numrows * k);
     440         666 :     isnull = palloc0_array(bool, numrows * k);
     441             : 
     442      964244 :     for (i = 0; i < numrows; i++)
     443             :     {
     444      963578 :         items[i].values = &values[i * k];
     445      963578 :         items[i].isnull = &isnull[i * k];
     446             :     }
     447             : 
     448             :     /*
     449             :      * For each dimension, set up sort-support and fill in the values from the
     450             :      * sample data.
     451             :      *
     452             :      * We use the column data types' default sort operators and collations;
     453             :      * perhaps at some point it'd be worth using column-specific collations?
     454             :      */
     455        2226 :     for (i = 0; i < k; i++)
     456             :     {
     457             :         Oid         typid;
     458             :         TypeCacheEntry *type;
     459        1560 :         Oid         collid = InvalidOid;
     460        1560 :         VacAttrStats *colstat = data->stats[combination[i]];
     461             : 
     462        1560 :         typid = colstat->attrtypid;
     463        1560 :         collid = colstat->attrcollid;
     464             : 
     465        1560 :         type = lookup_type_cache(typid, TYPECACHE_LT_OPR);
     466        1560 :         if (type->lt_opr == InvalidOid) /* shouldn't happen */
     467           0 :             elog(ERROR, "cache lookup failed for ordering operator for type %u",
     468             :                  typid);
     469             : 
     470             :         /* prepare the sort function for this dimension */
     471        1560 :         multi_sort_add_dimension(mss, i, type->lt_opr, collid);
     472             : 
     473             :         /* accumulate all the data for this dimension into the arrays */
     474     2228854 :         for (j = 0; j < numrows; j++)
     475             :         {
     476     2227294 :             items[j].values[i] = data->values[combination[i]][j];
     477     2227294 :             items[j].isnull[i] = data->nulls[combination[i]][j];
     478             :         }
     479             :     }
     480             : 
     481             :     /* We can sort the array now ... */
     482         666 :     qsort_interruptible(items, numrows, sizeof(SortItem),
     483             :                         multi_sort_compare, mss);
     484             : 
     485             :     /* ... and count the number of distinct combinations */
     486             : 
     487         666 :     f1 = 0;
     488         666 :     cnt = 1;
     489         666 :     d = 1;
     490      963578 :     for (i = 1; i < numrows; i++)
     491             :     {
     492      962912 :         if (multi_sort_compare(&items[i], &items[i - 1], mss) != 0)
     493             :         {
     494      284676 :             if (cnt == 1)
     495      146248 :                 f1 += 1;
     496             : 
     497      284676 :             d++;
     498      284676 :             cnt = 0;
     499             :         }
     500             : 
     501      962912 :         cnt += 1;
     502             :     }
     503             : 
     504         666 :     if (cnt == 1)
     505         210 :         f1 += 1;
     506             : 
     507         666 :     return estimate_ndistinct(totalrows, numrows, d, f1);
     508             : }
     509             : 
     510             : /* The Duj1 estimator (already used in analyze.c). */
     511             : static double
     512         666 : estimate_ndistinct(double totalrows, int numrows, int d, int f1)
     513             : {
     514             :     double      numer,
     515             :                 denom,
     516             :                 ndistinct;
     517             : 
     518         666 :     numer = (double) numrows * (double) d;
     519             : 
     520         666 :     denom = (double) (numrows - f1) +
     521         666 :         (double) f1 * (double) numrows / totalrows;
     522             : 
     523         666 :     ndistinct = numer / denom;
     524             : 
     525             :     /* Clamp to sane range in case of roundoff error */
     526         666 :     if (ndistinct < (double) d)
     527           0 :         ndistinct = (double) d;
     528             : 
     529         666 :     if (ndistinct > totalrows)
     530           0 :         ndistinct = totalrows;
     531             : 
     532         666 :     return floor(ndistinct + 0.5);
     533             : }
     534             : 
     535             : /*
     536             :  * n_choose_k
     537             :  *      computes binomial coefficients using an algorithm that is both
     538             :  *      efficient and prevents overflows
     539             :  */
     540             : static int
     541         330 : n_choose_k(int n, int k)
     542             : {
     543             :     int         d,
     544             :                 r;
     545             : 
     546             :     Assert((k > 0) && (n >= k));
     547             : 
     548             :     /* use symmetry of the binomial coefficients */
     549         330 :     k = Min(k, n - k);
     550             : 
     551         330 :     r = 1;
     552         468 :     for (d = 1; d <= k; ++d)
     553             :     {
     554         138 :         r *= n--;
     555         138 :         r /= d;
     556             :     }
     557             : 
     558         330 :     return r;
     559             : }
     560             : 
     561             : /*
     562             :  * num_combinations
     563             :  *      number of combinations, excluding single-value combinations
     564             :  */
     565             : static int
     566         222 : num_combinations(int n)
     567             : {
     568         222 :     return (1 << n) - (n + 1);
     569             : }
     570             : 
     571             : /*
     572             :  * generator_init
     573             :  *      initialize the generator of combinations
     574             :  *
     575             :  * The generator produces combinations of K elements in the interval [0, N).
     576             :  * We prebuild all the combinations in this method, which is simpler than
     577             :  * generating them on the fly.
     578             :  */
     579             : static CombinationGenerator *
     580         330 : generator_init(int n, int k)
     581             : {
     582             :     CombinationGenerator *state;
     583             : 
     584             :     Assert((n >= k) && (k > 0));
     585             : 
     586             :     /* allocate the generator state as a single chunk of memory */
     587         330 :     state = palloc_object(CombinationGenerator);
     588             : 
     589         330 :     state->ncombinations = n_choose_k(n, k);
     590             : 
     591             :     /* pre-allocate space for all combinations */
     592         330 :     state->combinations = palloc_array(int, k * state->ncombinations);
     593             : 
     594         330 :     state->current = 0;
     595         330 :     state->k = k;
     596         330 :     state->n = n;
     597             : 
     598             :     /* now actually pre-generate all the combinations of K elements */
     599         330 :     generate_combinations(state);
     600             : 
     601             :     /* make sure we got the expected number of combinations */
     602             :     Assert(state->current == state->ncombinations);
     603             : 
     604             :     /* reset the number, so we start with the first one */
     605         330 :     state->current = 0;
     606             : 
     607         330 :     return state;
     608             : }
     609             : 
     610             : /*
     611             :  * generator_next
     612             :  *      returns the next combination from the prebuilt list
     613             :  *
     614             :  * Returns a combination of K array indexes (0 .. N), as specified to
     615             :  * generator_init), or NULL when there are no more combination.
     616             :  */
     617             : static int *
     618         996 : generator_next(CombinationGenerator *state)
     619             : {
     620         996 :     if (state->current == state->ncombinations)
     621         330 :         return NULL;
     622             : 
     623         666 :     return &state->combinations[state->k * state->current++];
     624             : }
     625             : 
     626             : /*
     627             :  * generator_free
     628             :  *      free the internal state of the generator
     629             :  *
     630             :  * Releases the generator internal state (pre-built combinations).
     631             :  */
     632             : static void
     633         330 : generator_free(CombinationGenerator *state)
     634             : {
     635         330 :     pfree(state->combinations);
     636         330 :     pfree(state);
     637         330 : }
     638             : 
     639             : /*
     640             :  * generate_combinations_recurse
     641             :  *      given a prefix, generate all possible combinations
     642             :  *
     643             :  * Given a prefix (first few elements of the combination), generate following
     644             :  * elements recursively. We generate the combinations in lexicographic order,
     645             :  * which eliminates permutations of the same combination.
     646             :  */
     647             : static void
     648        2556 : generate_combinations_recurse(CombinationGenerator *state,
     649             :                               int index, int start, int *current)
     650             : {
     651             :     /* If we haven't filled all the elements, simply recurse. */
     652        2556 :     if (index < state->k)
     653             :     {
     654             :         int         i;
     655             : 
     656             :         /*
     657             :          * The values have to be in ascending order, so make sure we start
     658             :          * with the value passed by parameter.
     659             :          */
     660             : 
     661        4116 :         for (i = start; i < state->n; i++)
     662             :         {
     663        2226 :             current[index] = i;
     664        2226 :             generate_combinations_recurse(state, (index + 1), (i + 1), current);
     665             :         }
     666             : 
     667        1890 :         return;
     668             :     }
     669             :     else
     670             :     {
     671             :         /* we got a valid combination, add it to the array */
     672         666 :         memcpy(&state->combinations[(state->k * state->current)],
     673         666 :                current, state->k * sizeof(int));
     674         666 :         state->current++;
     675             :     }
     676             : }
     677             : 
     678             : /*
     679             :  * generate_combinations
     680             :  *      generate all k-combinations of N elements
     681             :  */
     682             : static void
     683         330 : generate_combinations(CombinationGenerator *state)
     684             : {
     685         330 :     int        *current = palloc0_array(int, state->k);
     686             : 
     687         330 :     generate_combinations_recurse(state, 0, 0, current);
     688             : 
     689         330 :     pfree(current);
     690         330 : }

Generated by: LCOV version 1.16