diff options
Diffstat (limited to 'src/backend/commands/functioncmds.c')
-rw-r--r-- | src/backend/commands/functioncmds.c | 145 |
1 files changed, 122 insertions, 23 deletions
diff --git a/src/backend/commands/functioncmds.c b/src/backend/commands/functioncmds.c index 7a4e104623b..199029b7a85 100644 --- a/src/backend/commands/functioncmds.c +++ b/src/backend/commands/functioncmds.c @@ -53,15 +53,18 @@ #include "commands/proclang.h" #include "executor/execdesc.h" #include "executor/executor.h" +#include "executor/functions.h" #include "funcapi.h" #include "miscadmin.h" #include "optimizer/optimizer.h" +#include "parser/analyze.h" #include "parser/parse_coerce.h" #include "parser/parse_collate.h" #include "parser/parse_expr.h" #include "parser/parse_func.h" #include "parser/parse_type.h" #include "pgstat.h" +#include "tcop/utility.h" #include "utils/acl.h" #include "utils/builtins.h" #include "utils/fmgroids.h" @@ -186,9 +189,11 @@ interpret_function_parameter_list(ParseState *pstate, Oid languageOid, ObjectType objtype, oidvector **parameterTypes, + List **parameterTypes_list, ArrayType **allParameterTypes, ArrayType **parameterModes, ArrayType **parameterNames, + List **inParameterNames_list, List **parameterDefaults, Oid *variadicArgType, Oid *requiredResultType) @@ -283,7 +288,11 @@ interpret_function_parameter_list(ParseState *pstate, /* handle input parameters */ if (fp->mode != FUNC_PARAM_OUT && fp->mode != FUNC_PARAM_TABLE) + { isinput = true; + if (parameterTypes_list) + *parameterTypes_list = lappend_oid(*parameterTypes_list, toid); + } /* handle signature parameters */ if (fp->mode == FUNC_PARAM_IN || fp->mode == FUNC_PARAM_INOUT || @@ -372,6 +381,9 @@ interpret_function_parameter_list(ParseState *pstate, have_names = true; } + if (inParameterNames_list) + *inParameterNames_list = lappend(*inParameterNames_list, makeString(fp->name ? fp->name : pstrdup(""))); + if (fp->defexpr) { Node *def; @@ -786,28 +798,10 @@ compute_function_attributes(ParseState *pstate, defel->defname); } - /* process required items */ if (as_item) *as = (List *) as_item->arg; - else - { - ereport(ERROR, - (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), - errmsg("no function body specified"))); - *as = NIL; /* keep compiler quiet */ - } - if (language_item) *language = strVal(language_item->arg); - else - { - ereport(ERROR, - (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), - errmsg("no language specified"))); - *language = NULL; /* keep compiler quiet */ - } - - /* process optional items */ if (transform_item) *transform = transform_item->arg; if (windowfunc_item) @@ -856,10 +850,26 @@ compute_function_attributes(ParseState *pstate, */ static void interpret_AS_clause(Oid languageOid, const char *languageName, - char *funcname, List *as, - char **prosrc_str_p, char **probin_str_p) + char *funcname, List *as, Node *sql_body_in, + List *parameterTypes, List *inParameterNames, + char **prosrc_str_p, char **probin_str_p, Node **sql_body_out) { - Assert(as != NIL); + if (!sql_body_in && !as) + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("no function body specified"))); + + if (sql_body_in && as) + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("duplicate function body specified"))); + + if (sql_body_in && languageOid != SQLlanguageId) + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("inline SQL function body only valid for language SQL"))); + + *sql_body_out = NULL; if (languageOid == ClanguageId) { @@ -881,6 +891,76 @@ interpret_AS_clause(Oid languageOid, const char *languageName, *prosrc_str_p = funcname; } } + else if (sql_body_in) + { + SQLFunctionParseInfoPtr pinfo; + + pinfo = (SQLFunctionParseInfoPtr) palloc0(sizeof(SQLFunctionParseInfo)); + + pinfo->fname = funcname; + pinfo->nargs = list_length(parameterTypes); + pinfo->argtypes = (Oid *) palloc(pinfo->nargs * sizeof(Oid)); + pinfo->argnames = (char **) palloc(pinfo->nargs * sizeof(char *)); + for (int i = 0; i < list_length(parameterTypes); i++) + { + char *s = strVal(list_nth(inParameterNames, i)); + + pinfo->argtypes[i] = list_nth_oid(parameterTypes, i); + if (IsPolymorphicType(pinfo->argtypes[i])) + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("SQL function with unquoted function body cannot have polymorphic arguments"))); + + if (s[0] != '\0') + pinfo->argnames[i] = s; + else + pinfo->argnames[i] = NULL; + } + + if (IsA(sql_body_in, List)) + { + List *stmts = linitial_node(List, castNode(List, sql_body_in)); + ListCell *lc; + List *transformed_stmts = NIL; + + foreach(lc, stmts) + { + Node *stmt = lfirst(lc); + Query *q; + ParseState *pstate = make_parsestate(NULL); + + sql_fn_parser_setup(pstate, pinfo); + q = transformStmt(pstate, stmt); + if (q->commandType == CMD_UTILITY) + ereport(ERROR, + errcode(ERRCODE_FEATURE_NOT_SUPPORTED), + errmsg("%s is not yet supported in unquoted SQL function body", + GetCommandTagName(CreateCommandTag(q->utilityStmt)))); + transformed_stmts = lappend(transformed_stmts, q); + free_parsestate(pstate); + } + + *sql_body_out = (Node *) list_make1(transformed_stmts); + } + else + { + Query *q; + ParseState *pstate = make_parsestate(NULL); + + sql_fn_parser_setup(pstate, pinfo); + q = transformStmt(pstate, sql_body_in); + if (q->commandType == CMD_UTILITY) + ereport(ERROR, + errcode(ERRCODE_FEATURE_NOT_SUPPORTED), + errmsg("%s is not yet supported in unquoted SQL function body", + GetCommandTagName(CreateCommandTag(q->utilityStmt)))); + + *sql_body_out = (Node *) q; + } + + *probin_str_p = NULL; + *prosrc_str_p = NULL; + } else { /* Everything else wants the given string in prosrc. */ @@ -919,6 +999,7 @@ CreateFunction(ParseState *pstate, CreateFunctionStmt *stmt) { char *probin_str; char *prosrc_str; + Node *prosqlbody; Oid prorettype; bool returnsSet; char *language; @@ -929,9 +1010,11 @@ CreateFunction(ParseState *pstate, CreateFunctionStmt *stmt) Oid namespaceId; AclResult aclresult; oidvector *parameterTypes; + List *parameterTypes_list = NIL; ArrayType *allParameterTypes; ArrayType *parameterModes; ArrayType *parameterNames; + List *inParameterNames_list = NIL; List *parameterDefaults; Oid variadicArgType; List *trftypes_list = NIL; @@ -962,6 +1045,8 @@ CreateFunction(ParseState *pstate, CreateFunctionStmt *stmt) get_namespace_name(namespaceId)); /* Set default attributes */ + as_clause = NIL; + language = NULL; isWindowFunc = false; isStrict = false; security = false; @@ -983,6 +1068,16 @@ CreateFunction(ParseState *pstate, CreateFunctionStmt *stmt) &proconfig, &procost, &prorows, &prosupport, ¶llel); + if (!language) + { + if (stmt->sql_body) + language = "sql"; + else + ereport(ERROR, + (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), + errmsg("no language specified"))); + } + /* Look up the language and validate permissions */ languageTuple = SearchSysCache1(LANGNAME, PointerGetDatum(language)); if (!HeapTupleIsValid(languageTuple)) @@ -1053,9 +1148,11 @@ CreateFunction(ParseState *pstate, CreateFunctionStmt *stmt) languageOid, stmt->is_procedure ? OBJECT_PROCEDURE : OBJECT_FUNCTION, ¶meterTypes, + ¶meterTypes_list, &allParameterTypes, ¶meterModes, ¶meterNames, + &inParameterNames_list, ¶meterDefaults, &variadicArgType, &requiredResultType); @@ -1112,8 +1209,9 @@ CreateFunction(ParseState *pstate, CreateFunctionStmt *stmt) trftypes = NULL; } - interpret_AS_clause(languageOid, language, funcname, as_clause, - &prosrc_str, &probin_str); + interpret_AS_clause(languageOid, language, funcname, as_clause, stmt->sql_body, + parameterTypes_list, inParameterNames_list, + &prosrc_str, &probin_str, &prosqlbody); /* * Set default values for COST and ROWS depending on other parameters; @@ -1155,6 +1253,7 @@ CreateFunction(ParseState *pstate, CreateFunctionStmt *stmt) languageValidator, prosrc_str, /* converted to text later */ probin_str, /* converted to text later */ + prosqlbody, stmt->is_procedure ? PROKIND_PROCEDURE : (isWindowFunc ? PROKIND_WINDOW : PROKIND_FUNCTION), security, isLeakProof, |