LCOV - code coverage report
Current view: top level - src/backend/statistics - mvdistinct.c (source / functions) Hit Total Coverage
Test: PostgreSQL 18devel Lines: 182 199 91.5 %
Date: 2024-10-10 05:14:53 Functions: 14 17 82.4 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : /*-------------------------------------------------------------------------
       2             :  *
       3             :  * mvdistinct.c
       4             :  *    POSTGRES multivariate ndistinct coefficients
       5             :  *
       6             :  * Estimating the number of groups in a combination of columns (e.g. for
       7             :  * GROUP BY) is tricky, and the estimation error is often significant.
       8             :  *
       9             :  * The multivariate ndistinct coefficients address this by storing ndistinct
      10             :  * estimates for combinations of the user-specified columns.  So for example
      11             :  * given a statistics object on three columns (a,b,c), this module estimates
      12             :  * and stores n-distinct for (a,b), (a,c), (b,c) and (a,b,c).  The per-column
      13             :  * estimates are already available in pg_statistic.
      14             :  *
      15             :  *
      16             :  * Portions Copyright (c) 1996-2024, PostgreSQL Global Development Group
      17             :  * Portions Copyright (c) 1994, Regents of the University of California
      18             :  *
      19             :  * IDENTIFICATION
      20             :  *    src/backend/statistics/mvdistinct.c
      21             :  *
      22             :  *-------------------------------------------------------------------------
      23             :  */
      24             : #include "postgres.h"
      25             : 
      26             : #include <math.h>
      27             : 
      28             : #include "catalog/pg_statistic_ext.h"
      29             : #include "catalog/pg_statistic_ext_data.h"
      30             : #include "lib/stringinfo.h"
      31             : #include "statistics/extended_stats_internal.h"
      32             : #include "statistics/statistics.h"
      33             : #include "utils/fmgrprotos.h"
      34             : #include "utils/syscache.h"
      35             : #include "utils/typcache.h"
      36             : #include "varatt.h"
      37             : 
      38             : static double ndistinct_for_combination(double totalrows, StatsBuildData *data,
      39             :                                         int k, int *combination);
      40             : static double estimate_ndistinct(double totalrows, int numrows, int d, int f1);
      41             : static int  n_choose_k(int n, int k);
      42             : static int  num_combinations(int n);
      43             : 
      44             : /* size of the struct header fields (magic, type, nitems) */
      45             : #define SizeOfHeader        (3 * sizeof(uint32))
      46             : 
      47             : /* size of a serialized ndistinct item (coefficient, natts, atts) */
      48             : #define SizeOfItem(natts) \
      49             :     (sizeof(double) + sizeof(int) + (natts) * sizeof(AttrNumber))
      50             : 
      51             : /* minimal size of a ndistinct item (with two attributes) */
      52             : #define MinSizeOfItem   SizeOfItem(2)
      53             : 
      54             : /* minimal size of mvndistinct, when all items are minimal */
      55             : #define MinSizeOfItems(nitems)  \
      56             :     (SizeOfHeader + (nitems) * MinSizeOfItem)
      57             : 
      58             : /* Combination generator API */
      59             : 
      60             : /* internal state for generator of k-combinations of n elements */
      61             : typedef struct CombinationGenerator
      62             : {
      63             :     int         k;              /* size of the combination */
      64             :     int         n;              /* total number of elements */
      65             :     int         current;        /* index of the next combination to return */
      66             :     int         ncombinations;  /* number of combinations (size of array) */
      67             :     int        *combinations;   /* array of pre-built combinations */
      68             : } CombinationGenerator;
      69             : 
      70             : static CombinationGenerator *generator_init(int n, int k);
      71             : static void generator_free(CombinationGenerator *state);
      72             : static int *generator_next(CombinationGenerator *state);
      73             : static void generate_combinations(CombinationGenerator *state);
      74             : 
      75             : 
      76             : /*
      77             :  * statext_ndistinct_build
      78             :  *      Compute ndistinct coefficient for the combination of attributes.
      79             :  *
      80             :  * This computes the ndistinct estimate using the same estimator used
      81             :  * in analyze.c and then computes the coefficient.
      82             :  *
      83             :  * To handle expressions easily, we treat them as system attributes with
      84             :  * negative attnums, and offset everything by number of expressions to
      85             :  * allow using Bitmapsets.
      86             :  */
      87             : MVNDistinct *
      88         168 : statext_ndistinct_build(double totalrows, StatsBuildData *data)
      89             : {
      90             :     MVNDistinct *result;
      91             :     int         k;
      92             :     int         itemcnt;
      93         168 :     int         numattrs = data->nattnums;
      94         168 :     int         numcombs = num_combinations(numattrs);
      95             : 
      96         168 :     result = palloc(offsetof(MVNDistinct, items) +
      97         168 :                     numcombs * sizeof(MVNDistinctItem));
      98         168 :     result->magic = STATS_NDISTINCT_MAGIC;
      99         168 :     result->type = STATS_NDISTINCT_TYPE_BASIC;
     100         168 :     result->nitems = numcombs;
     101             : 
     102         168 :     itemcnt = 0;
     103         420 :     for (k = 2; k <= numattrs; k++)
     104             :     {
     105             :         int        *combination;
     106             :         CombinationGenerator *generator;
     107             : 
     108             :         /* generate combinations of K out of N elements */
     109         252 :         generator = generator_init(numattrs, k);
     110             : 
     111         768 :         while ((combination = generator_next(generator)))
     112             :         {
     113         516 :             MVNDistinctItem *item = &result->items[itemcnt];
     114             :             int         j;
     115             : 
     116         516 :             item->attributes = palloc(sizeof(AttrNumber) * k);
     117         516 :             item->nattributes = k;
     118             : 
     119             :             /* translate the indexes to attnums */
     120        1728 :             for (j = 0; j < k; j++)
     121             :             {
     122        1212 :                 item->attributes[j] = data->attnums[combination[j]];
     123             : 
     124             :                 Assert(AttributeNumberIsValid(item->attributes[j]));
     125             :             }
     126             : 
     127         516 :             item->ndistinct =
     128         516 :                 ndistinct_for_combination(totalrows, data, k, combination);
     129             : 
     130         516 :             itemcnt++;
     131             :             Assert(itemcnt <= result->nitems);
     132             :         }
     133             : 
     134         252 :         generator_free(generator);
     135             :     }
     136             : 
     137             :     /* must consume exactly the whole output array */
     138             :     Assert(itemcnt == result->nitems);
     139             : 
     140         168 :     return result;
     141             : }
     142             : 
     143             : /*
     144             :  * statext_ndistinct_load
     145             :  *      Load the ndistinct value for the indicated pg_statistic_ext tuple
     146             :  */
     147             : MVNDistinct *
     148         402 : statext_ndistinct_load(Oid mvoid, bool inh)
     149             : {
     150             :     MVNDistinct *result;
     151             :     bool        isnull;
     152             :     Datum       ndist;
     153             :     HeapTuple   htup;
     154             : 
     155         402 :     htup = SearchSysCache2(STATEXTDATASTXOID,
     156             :                            ObjectIdGetDatum(mvoid), BoolGetDatum(inh));
     157         402 :     if (!HeapTupleIsValid(htup))
     158           0 :         elog(ERROR, "cache lookup failed for statistics object %u", mvoid);
     159             : 
     160         402 :     ndist = SysCacheGetAttr(STATEXTDATASTXOID, htup,
     161             :                             Anum_pg_statistic_ext_data_stxdndistinct, &isnull);
     162         402 :     if (isnull)
     163           0 :         elog(ERROR,
     164             :              "requested statistics kind \"%c\" is not yet built for statistics object %u",
     165             :              STATS_EXT_NDISTINCT, mvoid);
     166             : 
     167         402 :     result = statext_ndistinct_deserialize(DatumGetByteaPP(ndist));
     168             : 
     169         402 :     ReleaseSysCache(htup);
     170             : 
     171         402 :     return result;
     172             : }
     173             : 
     174             : /*
     175             :  * statext_ndistinct_serialize
     176             :  *      serialize ndistinct to the on-disk bytea format
     177             :  */
     178             : bytea *
     179         168 : statext_ndistinct_serialize(MVNDistinct *ndistinct)
     180             : {
     181             :     int         i;
     182             :     bytea      *output;
     183             :     char       *tmp;
     184             :     Size        len;
     185             : 
     186             :     Assert(ndistinct->magic == STATS_NDISTINCT_MAGIC);
     187             :     Assert(ndistinct->type == STATS_NDISTINCT_TYPE_BASIC);
     188             : 
     189             :     /*
     190             :      * Base size is the varlena header plus the header fields (magic, type,
     191             :      * nitems); the space needed by each item is added in the loop below.
     192             :      */
     193         168 :     len = VARHDRSZ + SizeOfHeader;
     194             : 
     195             :     /* and also include space for the actual attribute numbers */
     196         684 :     for (i = 0; i < ndistinct->nitems; i++)
     197             :     {
     198             :         int         nmembers;
     199             : 
     200         516 :         nmembers = ndistinct->items[i].nattributes;
     201             :         Assert(nmembers >= 2);
     202             : 
     203         516 :         len += SizeOfItem(nmembers);
     204             :     }
     205             : 
     206         168 :     output = (bytea *) palloc(len);
     207         168 :     SET_VARSIZE(output, len);
     208             : 
     209         168 :     tmp = VARDATA(output);
     210             : 
     211             :     /* Store the base struct values (magic, type, nitems) */
     212         168 :     memcpy(tmp, &ndistinct->magic, sizeof(uint32));
     213         168 :     tmp += sizeof(uint32);
     214         168 :     memcpy(tmp, &ndistinct->type, sizeof(uint32));
     215         168 :     tmp += sizeof(uint32);
     216         168 :     memcpy(tmp, &ndistinct->nitems, sizeof(uint32));
     217         168 :     tmp += sizeof(uint32);
     218             : 
     219             :     /*
     220             :      * store ndistinct, attribute count and attribute numbers for each entry
     221             :      */
     222         684 :     for (i = 0; i < ndistinct->nitems; i++)
     223             :     {
     224         516 :         MVNDistinctItem item = ndistinct->items[i];
     225         516 :         int         nmembers = item.nattributes;
     226             : 
     227         516 :         memcpy(tmp, &item.ndistinct, sizeof(double));
     228         516 :         tmp += sizeof(double);
     229         516 :         memcpy(tmp, &nmembers, sizeof(int));
     230         516 :         tmp += sizeof(int);
     231             : 
     232         516 :         memcpy(tmp, item.attributes, sizeof(AttrNumber) * nmembers);
     233         516 :         tmp += nmembers * sizeof(AttrNumber);
     234             : 
     235             :         /* protect against overflows */
     236             :         Assert(tmp <= ((char *) output + len));
     237             :     }
     238             : 
     239             :     /* check we used exactly the expected space */
     240             :     Assert(tmp == ((char *) output + len));
     241             : 
     242         168 :     return output;
     243             : }
     244             : 
     245             : /*
     246             :  * statext_ndistinct_deserialize
     247             :  *      Read an on-disk bytea format MVNDistinct to in-memory format
     248             :  */
     249             : MVNDistinct *
     250         426 : statext_ndistinct_deserialize(bytea *data)
     251             : {
     252             :     int         i;
     253             :     Size        minimum_size;
     254             :     MVNDistinct ndist;
     255             :     MVNDistinct *ndistinct;
     256             :     char       *tmp;
     257             : 
     258         426 :     if (data == NULL)
     259           0 :         return NULL;
     260             : 
     261             :     /* we expect at least the basic fields of MVNDistinct struct */
     262         426 :     if (VARSIZE_ANY_EXHDR(data) < SizeOfHeader)
     263           0 :         elog(ERROR, "invalid MVNDistinct size %zu (expected at least %zu)",
     264             :              VARSIZE_ANY_EXHDR(data), SizeOfHeader);
     265             : 
     266             :     /* initialize pointer to the data part (skip the varlena header) */
     267         426 :     tmp = VARDATA_ANY(data);
     268             : 
     269             :     /* read the header fields and perform basic sanity checks */
     270         426 :     memcpy(&ndist.magic, tmp, sizeof(uint32));
     271         426 :     tmp += sizeof(uint32);
     272         426 :     memcpy(&ndist.type, tmp, sizeof(uint32));
     273         426 :     tmp += sizeof(uint32);
     274         426 :     memcpy(&ndist.nitems, tmp, sizeof(uint32));
     275         426 :     tmp += sizeof(uint32);
     276             : 
     277         426 :     if (ndist.magic != STATS_NDISTINCT_MAGIC)
     278           0 :         elog(ERROR, "invalid ndistinct magic %08x (expected %08x)",
     279             :              ndist.magic, STATS_NDISTINCT_MAGIC);
     280         426 :     if (ndist.type != STATS_NDISTINCT_TYPE_BASIC)
     281           0 :         elog(ERROR, "invalid ndistinct type %d (expected %d)",
     282             :              ndist.type, STATS_NDISTINCT_TYPE_BASIC);
     283         426 :     if (ndist.nitems == 0)
     284           0 :         elog(ERROR, "invalid zero-length item array in MVNDistinct");
     285             : 
     286             :     /* what minimum bytea size do we expect for those parameters */
     287         426 :     minimum_size = MinSizeOfItems(ndist.nitems);
     288         426 :     if (VARSIZE_ANY_EXHDR(data) < minimum_size)
     289           0 :         elog(ERROR, "invalid MVNDistinct size %zu (expected at least %zu)",
     290             :              VARSIZE_ANY_EXHDR(data), minimum_size);
     291             : 
     292             :     /*
     293             :      * Allocate space for the ndistinct items (no space for each item's
     294             :      * attnos: those live in AttrNumber arrays allocated separately)
     295             :      */
     296         426 :     ndistinct = palloc0(MAXALIGN(offsetof(MVNDistinct, items)) +
     297         426 :                         (ndist.nitems * sizeof(MVNDistinctItem)));
     298         426 :     ndistinct->magic = ndist.magic;
     299         426 :     ndistinct->type = ndist.type;
     300         426 :     ndistinct->nitems = ndist.nitems;
     301             : 
     302        2310 :     for (i = 0; i < ndistinct->nitems; i++)
     303             :     {
     304        1884 :         MVNDistinctItem *item = &ndistinct->items[i];
     305             : 
     306             :         /* ndistinct value */
     307        1884 :         memcpy(&item->ndistinct, tmp, sizeof(double));
     308        1884 :         tmp += sizeof(double);
     309             : 
     310             :         /* number of attributes */
     311        1884 :         memcpy(&item->nattributes, tmp, sizeof(int));
     312        1884 :         tmp += sizeof(int);
     313             :         Assert((item->nattributes >= 2) && (item->nattributes <= STATS_MAX_DIMENSIONS));
     314             : 
     315             :         item->attributes
     316        1884 :             = (AttrNumber *) palloc(item->nattributes * sizeof(AttrNumber));
     317             : 
     318        1884 :         memcpy(item->attributes, tmp, sizeof(AttrNumber) * item->nattributes);
     319        1884 :         tmp += sizeof(AttrNumber) * item->nattributes;
     320             : 
     321             :         /* still within the bytea */
     322             :         Assert(tmp <= ((char *) data + VARSIZE_ANY(data)));
     323             :     }
     324             : 
     325             :     /* we should have consumed the whole bytea exactly */
     326             :     Assert(tmp == ((char *) data + VARSIZE_ANY(data)));
     327             : 
     328         426 :     return ndistinct;
     329             : }
     330             : 
     331             : /*
     332             :  * pg_ndistinct_in
     333             :  *      input routine for type pg_ndistinct
     334             :  *
     335             :  * pg_ndistinct is real enough to be a table column, but it has no
     336             :  * operations of its own, and disallows input (just like pg_node_tree).
     337             :  */
     338             : Datum
     339           0 : pg_ndistinct_in(PG_FUNCTION_ARGS)
     340             : {
     341           0 :     ereport(ERROR,
     342             :             (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
     343             :              errmsg("cannot accept a value of type %s", "pg_ndistinct")));
     344             : 
     345             :     PG_RETURN_VOID();           /* keep compiler quiet */
     346             : }
     347             : 
     348             : /*
     349             :  * pg_ndistinct_out
     350             :  *      output routine for type pg_ndistinct
     351             :  *
     352             :  * Produces a human-readable representation of the value.
     353             :  */
     354             : Datum
     355          24 : pg_ndistinct_out(PG_FUNCTION_ARGS)
     356             : {
     357          24 :     bytea      *data = PG_GETARG_BYTEA_PP(0);
     358          24 :     MVNDistinct *ndist = statext_ndistinct_deserialize(data);
     359             :     int         i;
     360             :     StringInfoData str;
     361             : 
     362          24 :     initStringInfo(&str);
     363          24 :     appendStringInfoChar(&str, '{');
     364             : 
     365         120 :     for (i = 0; i < ndist->nitems; i++)
     366             :     {
     367             :         int         j;
     368          96 :         MVNDistinctItem item = ndist->items[i];
     369             : 
     370          96 :         if (i > 0)
     371          72 :             appendStringInfoString(&str, ", ");
     372             : 
     373         312 :         for (j = 0; j < item.nattributes; j++)
     374             :         {
     375         216 :             AttrNumber  attnum = item.attributes[j];
     376             : 
     377         216 :             appendStringInfo(&str, "%s%d", (j == 0) ? "\"" : ", ", attnum);
     378             :         }
     379          96 :         appendStringInfo(&str, "\": %d", (int) item.ndistinct);
     380             :     }
     381             : 
     382          24 :     appendStringInfoChar(&str, '}');
     383             : 
     384          24 :     PG_RETURN_CSTRING(str.data);
     385             : }
     386             : 
     387             : /*
     388             :  * pg_ndistinct_recv
     389             :  *      binary input routine for type pg_ndistinct
     390             :  */
     391             : Datum
     392           0 : pg_ndistinct_recv(PG_FUNCTION_ARGS)
     393             : {
     394           0 :     ereport(ERROR,
     395             :             (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
     396             :              errmsg("cannot accept a value of type %s", "pg_ndistinct")));
     397             : 
     398             :     PG_RETURN_VOID();           /* keep compiler quiet */
     399             : }
     400             : 
     401             : /*
     402             :  * pg_ndistinct_send
     403             :  *      binary output routine for type pg_ndistinct
     404             :  *
     405             :  * n-distinct is serialized into a bytea value, so let's send that.
     406             :  */
     407             : Datum
     408           0 : pg_ndistinct_send(PG_FUNCTION_ARGS)
     409             : {
     410           0 :     return byteasend(fcinfo);
     411             : }
     412             : 
     413             : /*
     414             :  * ndistinct_for_combination
     415             :  *      Estimates number of distinct values in a combination of columns.
     416             :  *
     417             :  * This uses the same ndistinct estimator as compute_scalar_stats() in
     418             :  * ANALYZE, i.e.,
     419             :  *      n*d / (n - f1 + f1*n/N)
     420             :  *
     421             :  * except that instead of values in a single column we are dealing with
     422             :  * combination of multiple columns.
     423             :  */
     424             : static double
     425         516 : ndistinct_for_combination(double totalrows, StatsBuildData *data,
     426             :                           int k, int *combination)
     427             : {
     428             :     int         i,
     429             :                 j;
     430             :     int         f1,
     431             :                 cnt,
     432             :                 d;
     433             :     bool       *isnull;
     434             :     Datum      *values;
     435             :     SortItem   *items;
     436             :     MultiSortSupport mss;
     437         516 :     int         numrows = data->numrows;
     438             : 
     439         516 :     mss = multi_sort_init(k);
     440             : 
     441             :     /*
     442             :      * In order to determine the number of distinct elements, create separate
     443             :      * values[]/isnull[] arrays with all the data we have, then sort them
     444             :      * using the specified column combination as dimensions.  We could try to
     445             :      * sort in place, but it'd probably be more complex and bug-prone.
     446             :      */
     447         516 :     items = (SortItem *) palloc(numrows * sizeof(SortItem));
     448         516 :     values = (Datum *) palloc0(sizeof(Datum) * numrows * k);
     449         516 :     isnull = (bool *) palloc0(sizeof(bool) * numrows * k);
     450             : 
     451      483552 :     for (i = 0; i < numrows; i++)
     452             :     {
     453      483036 :         items[i].values = &values[i * k];
     454      483036 :         items[i].isnull = &isnull[i * k];
     455             :     }
     456             : 
     457             :     /*
     458             :      * For each dimension, set up sort-support and fill in the values from the
     459             :      * sample data.
     460             :      *
     461             :      * We use the column data types' default sort operators and collations;
     462             :      * perhaps at some point it'd be worth using column-specific collations?
     463             :      */
     464        1728 :     for (i = 0; i < k; i++)
     465             :     {
     466             :         Oid         typid;
     467             :         TypeCacheEntry *type;
     468        1212 :         Oid         collid = InvalidOid;
     469        1212 :         VacAttrStats *colstat = data->stats[combination[i]];
     470             : 
     471        1212 :         typid = colstat->attrtypid;
     472        1212 :         collid = colstat->attrcollid;
     473             : 
     474        1212 :         type = lookup_type_cache(typid, TYPECACHE_LT_OPR);
     475        1212 :         if (type->lt_opr == InvalidOid) /* shouldn't happen */
     476           0 :             elog(ERROR, "cache lookup failed for ordering operator for type %u",
     477             :                  typid);
     478             : 
     479             :         /* prepare the sort function for this dimension */
     480        1212 :         multi_sort_add_dimension(mss, i, type->lt_opr, collid);
     481             : 
     482             :         /* accumulate all the data for this dimension into the arrays */
     483     1147278 :         for (j = 0; j < numrows; j++)
     484             :         {
     485     1146066 :             items[j].values[i] = data->values[combination[i]][j];
     486     1146066 :             items[j].isnull[i] = data->nulls[combination[i]][j];
     487             :         }
     488             :     }
     489             : 
     490             :     /* We can sort the array now ... */
     491         516 :     qsort_interruptible(items, numrows, sizeof(SortItem),
     492             :                         multi_sort_compare, mss);
     493             : 
     494             :     /* ... and count the number of distinct combinations */
     495             : 
     496         516 :     f1 = 0;
     497         516 :     cnt = 1;
     498         516 :     d = 1;
     499      483036 :     for (i = 1; i < numrows; i++)
     500             :     {
     501      482520 :         if (multi_sort_compare(&items[i], &items[i - 1], mss) != 0)
     502             :         {
     503       46710 :             if (cnt == 1)
     504       25842 :                 f1 += 1;
     505             : 
     506       46710 :             d++;
     507       46710 :             cnt = 0;
     508             :         }
     509             : 
     510      482520 :         cnt += 1;
     511             :     }
     512             : 
     513         516 :     if (cnt == 1)
     514         102 :         f1 += 1;
     515             : 
     516         516 :     return estimate_ndistinct(totalrows, numrows, d, f1);
     517             : }
     518             : 
     519             : /* The Duj1 estimator (already used in analyze.c). */
     520             : static double
     521         516 : estimate_ndistinct(double totalrows, int numrows, int d, int f1)
     522             : {
     523             :     double      numer,
     524             :                 denom,
     525             :                 ndistinct;
     526             : 
     527         516 :     numer = (double) numrows * (double) d;
     528             : 
     529         516 :     denom = (double) (numrows - f1) +
     530         516 :         (double) f1 * (double) numrows / totalrows;
     531             : 
     532         516 :     ndistinct = numer / denom;
     533             : 
     534             :     /* Clamp to sane range in case of roundoff error */
     535         516 :     if (ndistinct < (double) d)
     536           0 :         ndistinct = (double) d;
     537             : 
     538         516 :     if (ndistinct > totalrows)
     539           0 :         ndistinct = totalrows;
     540             : 
     541         516 :     return floor(ndistinct + 0.5);
     542             : }
     543             : 
     544             : /*
     545             :  * n_choose_k
     546             :  *      computes binomial coefficients using an algorithm that is both
     547             :  *      efficient and prevents overflows
     548             :  */
     549             : static int
     550         252 : n_choose_k(int n, int k)
     551             : {
     552             :     int         d,
     553             :                 r;
     554             : 
     555             :     Assert((k > 0) && (n >= k));
     556             : 
     557             :     /* use symmetry of the binomial coefficients */
     558         252 :     k = Min(k, n - k);
     559             : 
     560         252 :     r = 1;
     561         360 :     for (d = 1; d <= k; ++d)
     562             :     {
     563         108 :         r *= n--;
     564         108 :         r /= d;
     565             :     }
     566             : 
     567         252 :     return r;
     568             : }
     569             : 
     570             : /*
     571             :  * num_combinations
     572             :  *      number of combinations, excluding single-value combinations
     573             :  */
     574             : static int
     575         168 : num_combinations(int n)
     576             : {
     577         168 :     return (1 << n) - (n + 1);
     578             : }
     579             : 
     580             : /*
     581             :  * generator_init
     582             :  *      initialize the generator of combinations
     583             :  *
     584             :  * The generator produces combinations of K elements in the range 0 .. N-1.
     585             :  * We prebuild all the combinations in this method, which is simpler than
     586             :  * generating them on the fly.
     587             :  */
     588             : static CombinationGenerator *
     589         252 : generator_init(int n, int k)
     590             : {
     591             :     CombinationGenerator *state;
     592             : 
     593             :     Assert((n >= k) && (k > 0));
     594             : 
     595             :     /* allocate the generator state as a single chunk of memory */
     596         252 :     state = (CombinationGenerator *) palloc(sizeof(CombinationGenerator));
     597             : 
     598         252 :     state->ncombinations = n_choose_k(n, k);
     599             : 
     600             :     /* pre-allocate space for all combinations */
     601         252 :     state->combinations = (int *) palloc(sizeof(int) * k * state->ncombinations);
     602             : 
     603         252 :     state->current = 0;
     604         252 :     state->k = k;
     605         252 :     state->n = n;
     606             : 
     607             :     /* now actually pre-generate all the combinations of K elements */
     608         252 :     generate_combinations(state);
     609             : 
     610             :     /* make sure we got the expected number of combinations */
     611             :     Assert(state->current == state->ncombinations);
     612             : 
     613             :     /* reset the number, so we start with the first one */
     614         252 :     state->current = 0;
     615             : 
     616         252 :     return state;
     617             : }
     618             : 
     619             : /*
     620             :  * generator_next
     621             :  *      returns the next combination from the prebuilt list
     622             :  *
     623             :  * Returns a pointer (into the prebuilt array) to K array indexes in 0 .. N-1,
     624             :  * as specified to generator_init, or NULL when no combinations remain.
     625             :  */
     626             : static int *
     627         768 : generator_next(CombinationGenerator *state)
     628             : {
     629         768 :     if (state->current == state->ncombinations)
     630         252 :         return NULL;
     631             : 
     632         516 :     return &state->combinations[state->k * state->current++];
     633             : }
     634             : 
     635             : /*
     636             :  * generator_free
     637             :  *      free the generator and its internal state
     638             :  *
     639             :  * Releases the pre-built combinations array and the generator state itself.
     640             :  */
     641             : static void
     642         252 : generator_free(CombinationGenerator *state)
     643             : {
     644         252 :     pfree(state->combinations);
     645         252 :     pfree(state);
     646         252 : }
     647             : 
     648             : /*
     649             :  * generate_combinations_recurse
     650             :  *      given a prefix, generate all possible combinations
     651             :  *
     652             :  * current[0 .. index-1] holds the prefix; position index is tried with every
     653             :  * value in start .. n-1, recursing with start = i + 1.  Combinations are thus
     654             :  * generated in lexicographic order, without permutations of the same one.
     655             :  */
     656             : static void
     657        1980 : generate_combinations_recurse(CombinationGenerator *state,
     658             :                               int index, int start, int *current)
     659             : {
     660             :     /* If we haven't filled all k elements yet, recurse to fill the rest. */
     661        1980 :     if (index < state->k)
     662             :     {
     663             :         int         i;
     664             : 
     665             :         /*
     666             :          * Values must be ascending, so candidates begin at "start", which the
     667             :          * recursive call below sets to one past the value just chosen.
     668             :          */
     669             : 
     670        3192 :         for (i = start; i < state->n; i++)
     671             :         {
     672        1728 :             current[index] = i;
     673        1728 :             generate_combinations_recurse(state, (index + 1), (i + 1), current);
     674             :         }
     675             : 
     676        1464 :         return;
     677             :     }
     678             :     else
     679             :     {
     680             :         /* all k positions are filled: store it at slot state->current */
     681         516 :         memcpy(&state->combinations[(state->k * state->current)],
     682         516 :                current, state->k * sizeof(int));
     683         516 :         state->current++;
     684             :     }
     685             : }
     686             : 
     687             : /*
     688             :  * generate_combinations
     689             :  *      generate all k-combinations of N elements into state->combinations
     690             :  */
     691             : static void
     692         252 : generate_combinations(CombinationGenerator *state)
     693             : {
     694         252 :     int        *current = (int *) palloc0(sizeof(int) * state->k); /* scratch: combination being built */
     695             : 
     696         252 :     generate_combinations_recurse(state, 0, 0, current);
     697             : 
     698         252 :     pfree(current);
     699         252 : }

Generated by: LCOV version 1.14