LCOV - code coverage report
Current view: top level - src/backend/commands - aggregatecmds.c (source / functions) — summary columns: Hit, Total, Coverage
Test: PostgreSQL 16beta1 | Lines: 152 hit, 177 total, 85.9 % coverage
Date: 2023-05-30 16:15:03 | Functions: 2 hit, 2 total, 100.0 % coverage
Legend: Lines: hit | not hit

          Line data    Source code
       1             : /*-------------------------------------------------------------------------
       2             :  *
       3             :  * aggregatecmds.c
       4             :  *
       5             :  *    Routines for aggregate-manipulation commands
       6             :  *
       7             :  * Portions Copyright (c) 1996-2023, PostgreSQL Global Development Group
       8             :  * Portions Copyright (c) 1994, Regents of the University of California
       9             :  *
      10             :  *
      11             :  * IDENTIFICATION
      12             :  *    src/backend/commands/aggregatecmds.c
      13             :  *
      14             :  * DESCRIPTION
      15             :  *    The "DefineFoo" routines take the parse tree and pick out the
      16             :  *    appropriate arguments/flags, passing the results to the
      17             :  *    corresponding "FooDefine" routines (in src/catalog) that do
      18             :  *    the actual catalog-munging.  These routines also verify permission
      19             :  *    of the user to execute the command.
      20             :  *
      21             :  *-------------------------------------------------------------------------
      22             :  */
      23             : #include "postgres.h"
      24             : 
      25             : #include "access/htup_details.h"
      26             : #include "catalog/dependency.h"
      27             : #include "catalog/pg_aggregate.h"
      28             : #include "catalog/pg_namespace.h"
      29             : #include "catalog/pg_proc.h"
      30             : #include "catalog/pg_type.h"
      31             : #include "commands/alter.h"
      32             : #include "commands/defrem.h"
      33             : #include "miscadmin.h"
      34             : #include "parser/parse_func.h"
      35             : #include "parser/parse_type.h"
      36             : #include "utils/acl.h"
      37             : #include "utils/builtins.h"
      38             : #include "utils/lsyscache.h"
      39             : #include "utils/syscache.h"
      40             : 
      41             : 
      42             : static char extractModify(DefElem *defel);
      43             : 
      44             : 
      45             : /*
      46             :  *  DefineAggregate
      47             :  *
      48             :  * "oldstyle" signals the old (pre-8.2) style where the aggregate input type
      49             :  * is specified by a BASETYPE element in the parameters.  Otherwise,
      50             :  * "args" is a pair, whose first element is a list of FunctionParameter structs
      51             :  * defining the agg's arguments (both direct and aggregated), and whose second
      52             :  * element is an Integer node with the number of direct args, or -1 if this
      53             :  * isn't an ordered-set aggregate.
      54             :  * "parameters" is a list of DefElem representing the agg's definition clauses.
      55             :  */
      56             : ObjectAddress
      57         904 : DefineAggregate(ParseState *pstate,
      58             :                 List *name,
      59             :                 List *args,
      60             :                 bool oldstyle,
      61             :                 List *parameters,
      62             :                 bool replace)
      63             : {
      64             :     char       *aggName;
      65             :     Oid         aggNamespace;
      66             :     AclResult   aclresult;
      67         904 :     char        aggKind = AGGKIND_NORMAL;
      68         904 :     List       *transfuncName = NIL;
      69         904 :     List       *finalfuncName = NIL;
      70         904 :     List       *combinefuncName = NIL;
      71         904 :     List       *serialfuncName = NIL;
      72         904 :     List       *deserialfuncName = NIL;
      73         904 :     List       *mtransfuncName = NIL;
      74         904 :     List       *minvtransfuncName = NIL;
      75         904 :     List       *mfinalfuncName = NIL;
      76         904 :     bool        finalfuncExtraArgs = false;
      77         904 :     bool        mfinalfuncExtraArgs = false;
      78         904 :     char        finalfuncModify = 0;
      79         904 :     char        mfinalfuncModify = 0;
      80         904 :     List       *sortoperatorName = NIL;
      81         904 :     TypeName   *baseType = NULL;
      82         904 :     TypeName   *transType = NULL;
      83         904 :     TypeName   *mtransType = NULL;
      84         904 :     int32       transSpace = 0;
      85         904 :     int32       mtransSpace = 0;
      86         904 :     char       *initval = NULL;
      87         904 :     char       *minitval = NULL;
      88         904 :     char       *parallel = NULL;
      89             :     int         numArgs;
      90         904 :     int         numDirectArgs = 0;
      91             :     oidvector  *parameterTypes;
      92             :     ArrayType  *allParameterTypes;
      93             :     ArrayType  *parameterModes;
      94             :     ArrayType  *parameterNames;
      95             :     List       *parameterDefaults;
      96             :     Oid         variadicArgType;
      97             :     Oid         transTypeId;
      98         904 :     Oid         mtransTypeId = InvalidOid;
      99             :     char        transTypeType;
     100         904 :     char        mtransTypeType = 0;
     101         904 :     char        proparallel = PROPARALLEL_UNSAFE;
     102             :     ListCell   *pl;
     103             : 
     104             :     /* Convert list of names to a name and namespace */
     105         904 :     aggNamespace = QualifiedNameGetCreationNamespace(name, &aggName);
     106             : 
     107             :     /* Check we have creation rights in target namespace */
     108         904 :     aclresult = object_aclcheck(NamespaceRelationId, aggNamespace, GetUserId(), ACL_CREATE);
     109         904 :     if (aclresult != ACLCHECK_OK)
     110           0 :         aclcheck_error(aclresult, OBJECT_SCHEMA,
     111           0 :                        get_namespace_name(aggNamespace));
     112             : 
     113             :     /* Deconstruct the output of the aggr_args grammar production */
     114         904 :     if (!oldstyle)
     115             :     {
     116             :         Assert(list_length(args) == 2);
     117         542 :         numDirectArgs = intVal(lsecond(args));
     118         542 :         if (numDirectArgs >= 0)
     119          22 :             aggKind = AGGKIND_ORDERED_SET;
     120             :         else
     121         520 :             numDirectArgs = 0;
     122         542 :         args = linitial_node(List, args);
     123             :     }
     124             : 
     125             :     /* Examine aggregate's definition clauses */
     126        4456 :     foreach(pl, parameters)
     127             :     {
     128        3552 :         DefElem    *defel = lfirst_node(DefElem, pl);
     129             : 
     130             :         /*
     131             :          * sfunc1, stype1, and initcond1 are accepted as obsolete spellings
     132             :          * for sfunc, stype, initcond.
     133             :          */
     134        3552 :         if (strcmp(defel->defname, "sfunc") == 0)
     135         862 :             transfuncName = defGetQualifiedName(defel);
     136        2690 :         else if (strcmp(defel->defname, "sfunc1") == 0)
     137          30 :             transfuncName = defGetQualifiedName(defel);
     138        2660 :         else if (strcmp(defel->defname, "finalfunc") == 0)
     139         388 :             finalfuncName = defGetQualifiedName(defel);
     140        2272 :         else if (strcmp(defel->defname, "combinefunc") == 0)
     141          32 :             combinefuncName = defGetQualifiedName(defel);
     142        2240 :         else if (strcmp(defel->defname, "serialfunc") == 0)
     143          36 :             serialfuncName = defGetQualifiedName(defel);
     144        2204 :         else if (strcmp(defel->defname, "deserialfunc") == 0)
     145          30 :             deserialfuncName = defGetQualifiedName(defel);
     146        2174 :         else if (strcmp(defel->defname, "msfunc") == 0)
     147          60 :             mtransfuncName = defGetQualifiedName(defel);
     148        2114 :         else if (strcmp(defel->defname, "minvfunc") == 0)
     149          60 :             minvtransfuncName = defGetQualifiedName(defel);
     150        2054 :         else if (strcmp(defel->defname, "mfinalfunc") == 0)
     151           0 :             mfinalfuncName = defGetQualifiedName(defel);
     152        2054 :         else if (strcmp(defel->defname, "finalfunc_extra") == 0)
     153          16 :             finalfuncExtraArgs = defGetBoolean(defel);
     154        2038 :         else if (strcmp(defel->defname, "mfinalfunc_extra") == 0)
     155           0 :             mfinalfuncExtraArgs = defGetBoolean(defel);
     156        2038 :         else if (strcmp(defel->defname, "finalfunc_modify") == 0)
     157          20 :             finalfuncModify = extractModify(defel);
     158        2018 :         else if (strcmp(defel->defname, "mfinalfunc_modify") == 0)
     159           0 :             mfinalfuncModify = extractModify(defel);
     160        2018 :         else if (strcmp(defel->defname, "sortop") == 0)
     161           8 :             sortoperatorName = defGetQualifiedName(defel);
     162        2010 :         else if (strcmp(defel->defname, "basetype") == 0)
     163         350 :             baseType = defGetTypeName(defel);
     164        1660 :         else if (strcmp(defel->defname, "hypothetical") == 0)
     165             :         {
     166           8 :             if (defGetBoolean(defel))
     167             :             {
     168           8 :                 if (aggKind == AGGKIND_NORMAL)
     169           0 :                     ereport(ERROR,
     170             :                             (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
     171             :                              errmsg("only ordered-set aggregates can be hypothetical")));
     172           8 :                 aggKind = AGGKIND_HYPOTHETICAL;
     173             :             }
     174             :         }
     175        1652 :         else if (strcmp(defel->defname, "stype") == 0)
     176         862 :             transType = defGetTypeName(defel);
     177         790 :         else if (strcmp(defel->defname, "stype1") == 0)
     178          30 :             transType = defGetTypeName(defel);
     179         760 :         else if (strcmp(defel->defname, "sspace") == 0)
     180           8 :             transSpace = defGetInt32(defel);
     181         752 :         else if (strcmp(defel->defname, "mstype") == 0)
     182          60 :             mtransType = defGetTypeName(defel);
     183         692 :         else if (strcmp(defel->defname, "msspace") == 0)
     184           0 :             mtransSpace = defGetInt32(defel);
     185         692 :         else if (strcmp(defel->defname, "initcond") == 0)
     186         558 :             initval = defGetString(defel);
     187         134 :         else if (strcmp(defel->defname, "initcond1") == 0)
     188          18 :             initval = defGetString(defel);
     189         116 :         else if (strcmp(defel->defname, "minitcond") == 0)
     190          16 :             minitval = defGetString(defel);
     191         100 :         else if (strcmp(defel->defname, "parallel") == 0)
     192          34 :             parallel = defGetString(defel);
     193             :         else
     194          66 :             ereport(WARNING,
     195             :                     (errcode(ERRCODE_SYNTAX_ERROR),
     196             :                      errmsg("aggregate attribute \"%s\" not recognized",
     197             :                             defel->defname)));
     198             :     }
     199             : 
     200             :     /*
     201             :      * make sure we have our required definitions
     202             :      */
     203         904 :     if (transType == NULL)
     204          12 :         ereport(ERROR,
     205             :                 (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
     206             :                  errmsg("aggregate stype must be specified")));
     207         892 :     if (transfuncName == NIL)
     208           0 :         ereport(ERROR,
     209             :                 (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
     210             :                  errmsg("aggregate sfunc must be specified")));
     211             : 
     212             :     /*
     213             :      * if mtransType is given, mtransfuncName and minvtransfuncName must be as
     214             :      * well; if not, then none of the moving-aggregate options should have
     215             :      * been given.
     216             :      */
     217         892 :     if (mtransType != NULL)
     218             :     {
     219          60 :         if (mtransfuncName == NIL)
     220           0 :             ereport(ERROR,
     221             :                     (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
     222             :                      errmsg("aggregate msfunc must be specified when mstype is specified")));
     223          60 :         if (minvtransfuncName == NIL)
     224           0 :             ereport(ERROR,
     225             :                     (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
     226             :                      errmsg("aggregate minvfunc must be specified when mstype is specified")));
     227             :     }
     228             :     else
     229             :     {
     230         832 :         if (mtransfuncName != NIL)
     231           0 :             ereport(ERROR,
     232             :                     (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
     233             :                      errmsg("aggregate msfunc must not be specified without mstype")));
     234         832 :         if (minvtransfuncName != NIL)
     235           0 :             ereport(ERROR,
     236             :                     (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
     237             :                      errmsg("aggregate minvfunc must not be specified without mstype")));
     238         832 :         if (mfinalfuncName != NIL)
     239           0 :             ereport(ERROR,
     240             :                     (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
     241             :                      errmsg("aggregate mfinalfunc must not be specified without mstype")));
     242         832 :         if (mtransSpace != 0)
     243           0 :             ereport(ERROR,
     244             :                     (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
     245             :                      errmsg("aggregate msspace must not be specified without mstype")));
     246         832 :         if (minitval != NULL)
     247           0 :             ereport(ERROR,
     248             :                     (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
     249             :                      errmsg("aggregate minitcond must not be specified without mstype")));
     250             :     }
     251             : 
     252             :     /*
     253             :      * Default values for modify flags can only be determined once we know the
     254             :      * aggKind.
     255             :      */
     256         892 :     if (finalfuncModify == 0)
     257         872 :         finalfuncModify = (aggKind == AGGKIND_NORMAL) ? AGGMODIFY_READ_ONLY : AGGMODIFY_READ_WRITE;
     258         892 :     if (mfinalfuncModify == 0)
     259         892 :         mfinalfuncModify = (aggKind == AGGKIND_NORMAL) ? AGGMODIFY_READ_ONLY : AGGMODIFY_READ_WRITE;
     260             : 
     261             :     /*
     262             :      * look up the aggregate's input datatype(s).
     263             :      */
     264         892 :     if (oldstyle)
     265             :     {
     266             :         /*
     267             :          * Old style: use basetype parameter.  This supports aggregates of
     268             :          * zero or one input, with input type ANY meaning zero inputs.
     269             :          *
     270             :          * Historically we allowed the command to look like basetype = 'ANY'
     271             :          * so we must do a case-insensitive comparison for the name ANY. Ugh.
     272             :          */
     273             :         Oid         aggArgTypes[1];
     274             : 
     275         356 :         if (baseType == NULL)
     276           6 :             ereport(ERROR,
     277             :                     (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
     278             :                      errmsg("aggregate input type must be specified")));
     279             : 
     280         350 :         if (pg_strcasecmp(TypeNameToString(baseType), "ANY") == 0)
     281             :         {
     282           6 :             numArgs = 0;
     283           6 :             aggArgTypes[0] = InvalidOid;
     284             :         }
     285             :         else
     286             :         {
     287         344 :             numArgs = 1;
     288         344 :             aggArgTypes[0] = typenameTypeId(NULL, baseType);
     289             :         }
     290         350 :         parameterTypes = buildoidvector(aggArgTypes, numArgs);
     291         350 :         allParameterTypes = NULL;
     292         350 :         parameterModes = NULL;
     293         350 :         parameterNames = NULL;
     294         350 :         parameterDefaults = NIL;
     295         350 :         variadicArgType = InvalidOid;
     296             :     }
     297             :     else
     298             :     {
     299             :         /*
     300             :          * New style: args is a list of FunctionParameters (possibly zero of
     301             :          * 'em).  We share functioncmds.c's code for processing them.
     302             :          */
     303             :         Oid         requiredResultType;
     304             : 
     305         536 :         if (baseType != NULL)
     306           0 :             ereport(ERROR,
     307             :                     (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
     308             :                      errmsg("basetype is redundant with aggregate input type specification")));
     309             : 
     310         536 :         numArgs = list_length(args);
     311         536 :         interpret_function_parameter_list(pstate,
     312             :                                           args,
     313             :                                           InvalidOid,
     314             :                                           OBJECT_AGGREGATE,
     315             :                                           &parameterTypes,
     316             :                                           NULL,
     317             :                                           &allParameterTypes,
     318             :                                           &parameterModes,
     319             :                                           &parameterNames,
     320             :                                           NULL,
     321             :                                           &parameterDefaults,
     322             :                                           &variadicArgType,
     323             :                                           &requiredResultType);
     324             :         /* Parameter defaults are not currently allowed by the grammar */
     325             :         Assert(parameterDefaults == NIL);
     326             :         /* There shouldn't have been any OUT parameters, either */
     327             :         Assert(requiredResultType == InvalidOid);
     328             :     }
     329             : 
     330             :     /*
     331             :      * look up the aggregate's transtype.
     332             :      *
     333             :      * transtype can't be a pseudo-type, since we need to be able to store
     334             :      * values of the transtype.  However, we can allow polymorphic transtype
     335             :      * in some cases (AggregateCreate will check).  Also, we allow "internal"
     336             :      * for functions that want to pass pointers to private data structures;
     337             :      * but allow that only to superusers, since you could crash the system (or
     338             :      * worse) by connecting up incompatible internal-using functions in an
     339             :      * aggregate.
     340             :      */
     341         880 :     transTypeId = typenameTypeId(NULL, transType);
     342         880 :     transTypeType = get_typtype(transTypeId);
     343         880 :     if (transTypeType == TYPTYPE_PSEUDO &&
     344         280 :         !IsPolymorphicType(transTypeId))
     345             :     {
     346          58 :         if (transTypeId == INTERNALOID && superuser())
     347             :              /* okay */ ;
     348             :         else
     349           0 :             ereport(ERROR,
     350             :                     (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
     351             :                      errmsg("aggregate transition data type cannot be %s",
     352             :                             format_type_be(transTypeId))));
     353             :     }
     354             : 
     355         880 :     if (serialfuncName && deserialfuncName)
     356             :     {
     357             :         /*
     358             :          * Serialization is only needed/allowed for transtype INTERNAL.
     359             :          */
     360          30 :         if (transTypeId != INTERNALOID)
     361           0 :             ereport(ERROR,
     362             :                     (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
     363             :                      errmsg("serialization functions may be specified only when the aggregate transition data type is %s",
     364             :                             format_type_be(INTERNALOID))));
     365             :     }
     366         850 :     else if (serialfuncName || deserialfuncName)
     367             :     {
     368             :         /*
     369             :          * Cannot specify one function without the other.
     370             :          */
     371           6 :         ereport(ERROR,
     372             :                 (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
     373             :                  errmsg("must specify both or neither of serialization and deserialization functions")));
     374             :     }
     375             : 
     376             :     /*
     377             :      * If a moving-aggregate transtype is specified, look that up.  Same
     378             :      * restrictions as for transtype.
     379             :      */
     380         874 :     if (mtransType)
     381             :     {
     382          60 :         mtransTypeId = typenameTypeId(NULL, mtransType);
     383          60 :         mtransTypeType = get_typtype(mtransTypeId);
     384          60 :         if (mtransTypeType == TYPTYPE_PSEUDO &&
     385           0 :             !IsPolymorphicType(mtransTypeId))
     386             :         {
     387           0 :             if (mtransTypeId == INTERNALOID && superuser())
     388             :                  /* okay */ ;
     389             :             else
     390           0 :                 ereport(ERROR,
     391             :                         (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
     392             :                          errmsg("aggregate transition data type cannot be %s",
     393             :                                 format_type_be(mtransTypeId))));
     394             :         }
     395             :     }
     396             : 
     397             :     /*
     398             :      * If we have an initval, and it's not for a pseudotype (particularly a
     399             :      * polymorphic type), make sure it's acceptable to the type's input
     400             :      * function.  We will store the initval as text, because the input
     401             :      * function isn't necessarily immutable (consider "now" for timestamp),
     402             :      * and we want to use the runtime not creation-time interpretation of the
     403             :      * value.  However, if it's an incorrect value it seems much more
     404             :      * user-friendly to complain at CREATE AGGREGATE time.
     405             :      */
     406         874 :     if (initval && transTypeType != TYPTYPE_PSEUDO)
     407             :     {
     408             :         Oid         typinput,
     409             :                     typioparam;
     410             : 
     411         370 :         getTypeInputInfo(transTypeId, &typinput, &typioparam);
     412         370 :         (void) OidInputFunctionCall(typinput, initval, typioparam, -1);
     413             :     }
     414             : 
     415             :     /*
     416             :      * Likewise for moving-aggregate initval.
     417             :      */
     418         874 :     if (minitval && mtransTypeType != TYPTYPE_PSEUDO)
     419             :     {
     420             :         Oid         typinput,
     421             :                     typioparam;
     422             : 
     423          16 :         getTypeInputInfo(mtransTypeId, &typinput, &typioparam);
     424          16 :         (void) OidInputFunctionCall(typinput, minitval, typioparam, -1);
     425             :     }
     426             : 
     427         874 :     if (parallel)
     428             :     {
     429          34 :         if (strcmp(parallel, "safe") == 0)
     430          28 :             proparallel = PROPARALLEL_SAFE;
     431           6 :         else if (strcmp(parallel, "restricted") == 0)
     432           0 :             proparallel = PROPARALLEL_RESTRICTED;
     433           6 :         else if (strcmp(parallel, "unsafe") == 0)
     434           0 :             proparallel = PROPARALLEL_UNSAFE;
     435             :         else
     436           6 :             ereport(ERROR,
     437             :                     (errcode(ERRCODE_SYNTAX_ERROR),
     438             :                      errmsg("parameter \"parallel\" must be SAFE, RESTRICTED, or UNSAFE")));
     439             :     }
     440             : 
     441             :     /*
     442             :      * Most of the argument-checking is done inside of AggregateCreate
     443             :      */
     444         868 :     return AggregateCreate(aggName, /* aggregate name */
     445             :                            aggNamespace,    /* namespace */
     446             :                            replace,
     447             :                            aggKind,
     448             :                            numArgs,
     449             :                            numDirectArgs,
     450             :                            parameterTypes,
     451             :                            PointerGetDatum(allParameterTypes),
     452             :                            PointerGetDatum(parameterModes),
     453             :                            PointerGetDatum(parameterNames),
     454             :                            parameterDefaults,
     455             :                            variadicArgType,
     456             :                            transfuncName,   /* step function name */
     457             :                            finalfuncName,   /* final function name */
     458             :                            combinefuncName, /* combine function name */
     459             :                            serialfuncName,  /* serial function name */
     460             :                            deserialfuncName,    /* deserial function name */
     461             :                            mtransfuncName,  /* fwd trans function name */
     462             :                            minvtransfuncName,   /* inv trans function name */
     463             :                            mfinalfuncName,  /* final function name */
     464             :                            finalfuncExtraArgs,
     465             :                            mfinalfuncExtraArgs,
     466             :                            finalfuncModify,
     467             :                            mfinalfuncModify,
     468             :                            sortoperatorName,    /* sort operator name */
     469             :                            transTypeId, /* transition data type */
     470             :                            transSpace,  /* transition space */
     471             :                            mtransTypeId,    /* transition data type */
     472             :                            mtransSpace, /* transition space */
     473             :                            initval, /* initial condition */
     474             :                            minitval,    /* initial condition */
     475             :                            proparallel);    /* parallel safe? */
     476             : }
     477             : 
     478             : /*
     479             :  * Convert the string form of [m]finalfunc_modify to the catalog representation
     480             :  */
     481             : static char
     482          20 : extractModify(DefElem *defel)
     483             : {
     484          20 :     char       *val = defGetString(defel);
     485             : 
     486          20 :     if (strcmp(val, "read_only") == 0)
     487           0 :         return AGGMODIFY_READ_ONLY;
     488          20 :     if (strcmp(val, "shareable") == 0)
     489          14 :         return AGGMODIFY_SHAREABLE;
     490           6 :     if (strcmp(val, "read_write") == 0)
     491           6 :         return AGGMODIFY_READ_WRITE;
     492           0 :     ereport(ERROR,
     493             :             (errcode(ERRCODE_SYNTAX_ERROR),
     494             :              errmsg("parameter \"%s\" must be READ_ONLY, SHAREABLE, or READ_WRITE",
     495             :                     defel->defname)));
     496             :     return 0;                   /* keep compiler quiet */
     497             : }

Generated by: LCOV version 1.14