diff --git a/doc/src/sgml/plpgsql.sgml b/doc/src/sgml/plpgsql.sgml index e937491e6b89..cbe5c2be28b8 100644 --- a/doc/src/sgml/plpgsql.sgml +++ b/doc/src/sgml/plpgsql.sgml @@ -5388,6 +5388,25 @@ a_output := a_output || $$ if v_$$ || referrer_keys.kind || $$ like '$$ + + + strict_expr_check + + + Enabling this check will cause PL/pgSQL to + check if a PL/pgSQL expression is just an + expression without any SQL clauses like FROM, + ORDER BY. This undocumented form of expressions + is allowed for compatibility reasons, but in some special cases + it doesn't allow to detect broken code. + + + + This check is allowed only when plpgsql.extra_errors + is set to "strict_expr_check". + + + The following example shows the effect of plpgsql.extra_warnings diff --git a/src/pl/plpgsql/src/pl_comp.c b/src/pl/plpgsql/src/pl_comp.c index 519f7695d7c1..c015e4de9888 100644 --- a/src/pl/plpgsql/src/pl_comp.c +++ b/src/pl/plpgsql/src/pl_comp.c @@ -786,6 +786,13 @@ plpgsql_compile_inline(char *proc_source) function->extra_warnings = 0; function->extra_errors = 0; + /* + * Although function->extra_errors is disabled, we want to + * do strict_expr_check inside annoymous block too. + */ + if (plpgsql_extra_errors & PLPGSQL_XCHECK_STRICTEXPRCHECK) + function->extra_errors = PLPGSQL_XCHECK_STRICTEXPRCHECK; + function->nstatements = 0; function->requires_procedure_resowner = false; function->has_exception_block = false; diff --git a/src/pl/plpgsql/src/pl_gram.y b/src/pl/plpgsql/src/pl_gram.y index 5612e66d0239..dcc581afdbf2 100644 --- a/src/pl/plpgsql/src/pl_gram.y +++ b/src/pl/plpgsql/src/pl_gram.y @@ -18,6 +18,7 @@ #include "catalog/namespace.h" #include "catalog/pg_proc.h" #include "catalog/pg_type.h" +#include "nodes/nodeFuncs.h" #include "parser/parser.h" #include "parser/parse_type.h" #include "parser/scanner.h" @@ -71,6 +72,7 @@ static PLpgSQL_expr *read_sql_construct(int until, const char *expected, RawParseMode parsemode, bool isexpression, + bool allowlist, bool valid_sql, int *startloc, int *endtoken, @@ -106,7 +108,7 @@ static PLpgSQL_row *make_scalar_list1(char *initial_name, PLpgSQL_datum *initial_datum, int lineno, int location, yyscan_t yyscanner); static void check_sql_expr(const char *stmt, - RawParseMode parseMode, int location, yyscan_t yyscanner); + RawParseMode parseMode, bool allowlist, int location, yyscan_t yyscanner); static void plpgsql_sql_error_callback(void *arg); static PLpgSQL_type *parse_datatype(const char *string, int location, yyscan_t yyscanner); static void check_labels(const char *start_label, @@ -117,6 +119,7 @@ static PLpgSQL_expr *read_cursor_args(PLpgSQL_var *cursor, int until, YYSTYPE *yylvalp, YYLTYPE *yyllocp, yyscan_t yyscanner); static List *read_raise_options(YYSTYPE *yylvalp, YYLTYPE *yyllocp, yyscan_t yyscanner); static void check_raise_parameters(PLpgSQL_stmt_raise *stmt); +static bool is_strict_expr(List *parsetree, int *errpos, bool allowlist); %} @@ -193,6 +196,7 @@ static void check_raise_parameters(PLpgSQL_stmt_raise *stmt); %type expr_until_semi %type expr_until_then expr_until_loop opt_expr_until_when %type opt_exitcond +%type expressions_until_then %type cursor_variable %type decl_cursor_arg @@ -914,7 +918,7 @@ stmt_perform : K_PERFORM */ new->expr = read_sql_construct(';', 0, 0, ";", RAW_PARSE_DEFAULT, - false, false, + false, false, false, &startloc, NULL, &yylval, &yylloc, yyscanner); /* overwrite "perform" ... */ @@ -924,7 +928,7 @@ stmt_perform : K_PERFORM strlen(new->expr->query)); /* offset syntax error position to account for that */ check_sql_expr(new->expr->query, new->expr->parseMode, - startloc + 1, yyscanner); + false, startloc + 1, yyscanner); $$ = (PLpgSQL_stmt *) new; } @@ -1001,7 +1005,7 @@ stmt_assign : T_DATUM plpgsql_push_back_token(T_DATUM, &yylval, &yylloc, yyscanner); new->expr = read_sql_construct(';', 0, 0, ";", pmode, - false, true, + false, false, true, NULL, NULL, &yylval, &yylloc, yyscanner); mark_expr_as_assignment_source(new->expr, $1.datum); @@ -1262,7 +1266,7 @@ case_when_list : case_when_list case_when } ; -case_when : K_WHEN expr_until_then proc_sect +case_when : K_WHEN expressions_until_then proc_sect { PLpgSQL_case_when *new = palloc(sizeof(PLpgSQL_case_when)); @@ -1292,6 +1296,15 @@ opt_case_else : } ; +expressions_until_then : + { + $$ = read_sql_construct(K_THEN, 0, 0, "THEN", + RAW_PARSE_PLPGSQL_EXPR, /* expr_list */ + true, true, true, NULL, NULL, + &yylval, &yylloc, yyscanner); + } + ; + stmt_loop : opt_loop_label K_LOOP loop_body { PLpgSQL_stmt_loop *new; @@ -1495,6 +1508,7 @@ for_control : for_variable K_IN RAW_PARSE_DEFAULT, true, false, + false, &expr1loc, &tok, &yylval, &yylloc, yyscanner); @@ -1513,7 +1527,7 @@ for_control : for_variable K_IN */ expr1->parseMode = RAW_PARSE_PLPGSQL_EXPR; check_sql_expr(expr1->query, expr1->parseMode, - expr1loc, yyscanner); + false, expr1loc, yyscanner); /* Read and check the second one */ expr2 = read_sql_expression2(K_LOOP, K_BY, @@ -1570,7 +1584,7 @@ for_control : for_variable K_IN /* Check syntax as a regular query */ check_sql_expr(expr1->query, expr1->parseMode, - expr1loc, yyscanner); + false, expr1loc, yyscanner); new = palloc0(sizeof(PLpgSQL_stmt_fors)); new->cmd_type = PLPGSQL_STMT_FORS; @@ -1902,7 +1916,7 @@ stmt_raise : K_RAISE expr = read_sql_construct(',', ';', K_USING, ", or ; or USING", RAW_PARSE_PLPGSQL_EXPR, - true, true, + true, false, true, NULL, &tok, &yylval, &yylloc, yyscanner); new->params = lappend(new->params, expr); @@ -2040,7 +2054,7 @@ stmt_dynexecute : K_EXECUTE expr = read_sql_construct(K_INTO, K_USING, ';', "INTO or USING or ;", RAW_PARSE_PLPGSQL_EXPR, - true, true, + true, false, true, NULL, &endtoken, &yylval, &yylloc, yyscanner); @@ -2080,7 +2094,7 @@ stmt_dynexecute : K_EXECUTE expr = read_sql_construct(',', ';', K_INTO, ", or ; or INTO", RAW_PARSE_PLPGSQL_EXPR, - true, true, + true, false, true, NULL, &endtoken, &yylval, &yylloc, yyscanner); new->params = lappend(new->params, expr); @@ -2713,7 +2727,7 @@ read_sql_expression(int until, const char *expected, YYSTYPE *yylvalp, YYLTYPE * { return read_sql_construct(until, 0, 0, expected, RAW_PARSE_PLPGSQL_EXPR, - true, true, NULL, NULL, + true, false, true, NULL, NULL, yylvalp, yyllocp, yyscanner); } @@ -2724,7 +2738,7 @@ read_sql_expression2(int until, int until2, const char *expected, { return read_sql_construct(until, until2, 0, expected, RAW_PARSE_PLPGSQL_EXPR, - true, true, NULL, endtoken, + true, false, true, NULL, endtoken, yylvalp, yyllocp, yyscanner); } @@ -2734,7 +2748,7 @@ read_sql_stmt(YYSTYPE *yylvalp, YYLTYPE *yyllocp, yyscan_t yyscanner) { return read_sql_construct(';', 0, 0, ";", RAW_PARSE_DEFAULT, - false, true, NULL, NULL, + false, false, true, NULL, NULL, yylvalp, yyllocp, yyscanner); } @@ -2747,6 +2761,7 @@ read_sql_stmt(YYSTYPE *yylvalp, YYLTYPE *yyllocp, yyscan_t yyscanner) * expected: text to use in complaining that terminator was not found * parsemode: raw_parser() mode to use * isexpression: whether to say we're reading an "expression" or a "statement" + * allowlist: the result can be list of expressions * valid_sql: whether to check the syntax of the expr * startloc: if not NULL, location of first token is stored at *startloc * endtoken: if not NULL, ending token is stored at *endtoken @@ -2759,6 +2774,7 @@ read_sql_construct(int until, const char *expected, RawParseMode parsemode, bool isexpression, + bool allowlist, bool valid_sql, int *startloc, int *endtoken, @@ -2854,7 +2870,7 @@ read_sql_construct(int until, pfree(ds.data); if (valid_sql) - check_sql_expr(expr->query, expr->parseMode, startlocation, yyscanner); + check_sql_expr(expr->query, expr->parseMode, allowlist, startlocation, yyscanner); return expr; } @@ -3175,7 +3191,7 @@ make_execsql_stmt(int firsttoken, int location, PLword *word, YYSTYPE *yylvalp, expr = make_plpgsql_expr(ds.data, RAW_PARSE_DEFAULT); pfree(ds.data); - check_sql_expr(expr->query, expr->parseMode, location, yyscanner); + check_sql_expr(expr->query, expr->parseMode, false, location, yyscanner); execsql = palloc0(sizeof(PLpgSQL_stmt_execsql)); execsql->cmd_type = PLPGSQL_STMT_EXECSQL; @@ -3775,11 +3791,15 @@ make_scalar_list1(char *initial_name, * If no error cursor is provided, we'll just point at "location". */ static void -check_sql_expr(const char *stmt, RawParseMode parseMode, int location, yyscan_t yyscanner) +check_sql_expr(const char *stmt, + RawParseMode parseMode, bool allowlist, + int location, yyscan_t yyscanner) { sql_error_callback_arg cbarg; ErrorContextCallback syntax_errcontext; MemoryContext oldCxt; + List *parsetree; + int errpos; if (!plpgsql_check_syntax) return; @@ -3793,11 +3813,25 @@ check_sql_expr(const char *stmt, RawParseMode parseMode, int location, yyscan_t error_context_stack = &syntax_errcontext; oldCxt = MemoryContextSwitchTo(plpgsql_compile_tmp_cxt); - (void) raw_parser(stmt, parseMode); + parsetree = raw_parser(stmt, parseMode); MemoryContextSwitchTo(oldCxt); /* Restore former ereport callback */ error_context_stack = syntax_errcontext.previous; + + if (plpgsql_curr_compile->extra_warnings & PLPGSQL_XCHECK_STRICTEXPRCHECK || + plpgsql_curr_compile->extra_errors & PLPGSQL_XCHECK_STRICTEXPRCHECK) + { + /* do this check only for expressions */ + if (parseMode == RAW_PARSE_DEFAULT) + return; + + if (!is_strict_expr(parsetree, &errpos, allowlist)) + ereport(plpgsql_curr_compile->extra_errors & PLPGSQL_XCHECK_STRICTEXPRCHECK ? ERROR : WARNING, + (errcode(ERRCODE_SYNTAX_ERROR), + errmsg("syntax of expression is not strict"), + parser_errposition(errpos != -1 ? location + errpos : location))); + } } static void @@ -3831,6 +3865,74 @@ plpgsql_sql_error_callback(void *arg) errposition(0); } +/* + * Returns true, when the only targetList is in parsetree. Cursors + * can require list of expressions or list of named expressions. + */ +static bool +is_strict_expr(List *parsetree, int *errpos, bool allowlist) +{ + RawStmt *rawstmt; + SelectStmt *select; + int targets = 0; + ListCell *lc; + + /* Top should be RawStmt */ + rawstmt = castNode(RawStmt, linitial(parsetree)); + + if (IsA(rawstmt->stmt, SelectStmt)) + { + select = (SelectStmt *) rawstmt->stmt; + } + else if (IsA(rawstmt->stmt, PLAssignStmt)) + { + select = castNode(SelectStmt, ((PLAssignStmt *) rawstmt->stmt)->val); + } + else + elog(ERROR, "unexpected node type"); + + if (!select->targetList) + { + *errpos = -1; + return false; + } + else + *errpos = exprLocation((Node *) select->targetList); + + if (select->distinctClause || + select->fromClause || + select->whereClause || + select->groupClause || + select->groupDistinct || + select->havingClause || + select->windowClause || + select->sortClause || + select->limitOffset || + select->limitCount || + select->limitOption || + select->lockingClause) + return false; + + foreach(lc, select->targetList) + { + ResTarget *rt = castNode(ResTarget, lfirst(lc)); + + if (targets++ >= 1 && !allowlist) + { + *errpos = exprLocation((Node *) rt); + return false; + } + + if (rt->name) + { + *errpos = exprLocation((Node *) rt); + return false; + } + } + + return true; +} + /* * Parse a SQL datatype name and produce a PLpgSQL_type structure. * @@ -4014,7 +4116,7 @@ read_cursor_args(PLpgSQL_var *cursor, int until, YYSTYPE *yylvalp, YYLTYPE *yyll item = read_sql_construct(',', ')', 0, ",\" or \")", RAW_PARSE_PLPGSQL_EXPR, - true, true, + true, false, true, NULL, &endtoken, yylvalp, yyllocp, yyscanner); diff --git a/src/pl/plpgsql/src/pl_handler.c b/src/pl/plpgsql/src/pl_handler.c index e9a729299470..b3ba3163e9a6 100644 --- a/src/pl/plpgsql/src/pl_handler.c +++ b/src/pl/plpgsql/src/pl_handler.c @@ -97,6 +97,8 @@ plpgsql_extra_checks_check_hook(char **newvalue, void **extra, GucSource source) extrachecks |= PLPGSQL_XCHECK_TOOMANYROWS; else if (pg_strcasecmp(tok, "strict_multi_assignment") == 0) extrachecks |= PLPGSQL_XCHECK_STRICTMULTIASSIGNMENT; + else if (pg_strcasecmp(tok, "strict_expr_check") == 0) + extrachecks |= PLPGSQL_XCHECK_STRICTEXPRCHECK; else if (pg_strcasecmp(tok, "all") == 0 || pg_strcasecmp(tok, "none") == 0) { GUC_check_errdetail("Key word \"%s\" cannot be combined with other key words.", tok); diff --git a/src/pl/plpgsql/src/plpgsql.h b/src/pl/plpgsql/src/plpgsql.h index 41e52b8ce718..459f5f2e2232 100644 --- a/src/pl/plpgsql/src/plpgsql.h +++ b/src/pl/plpgsql/src/plpgsql.h @@ -1195,6 +1195,7 @@ extern bool plpgsql_check_asserts; #define PLPGSQL_XCHECK_SHADOWVAR (1 << 1) #define PLPGSQL_XCHECK_TOOMANYROWS (1 << 2) #define PLPGSQL_XCHECK_STRICTMULTIASSIGNMENT (1 << 3) +#define PLPGSQL_XCHECK_STRICTEXPRCHECK (1 << 4) #define PLPGSQL_XCHECK_ALL ((int) ~0) extern int plpgsql_extra_warnings; diff --git a/src/test/regress/expected/plpgsql.out b/src/test/regress/expected/plpgsql.out index d8ce39dba3c1..8f4f5cb1183f 100644 --- a/src/test/regress/expected/plpgsql.out +++ b/src/test/regress/expected/plpgsql.out @@ -3084,6 +3084,20 @@ select shadowtest(1); t (1 row) +-- test of strict expression check +set plpgsql.extra_errors to 'strict_expr_check'; +create or replace function strict_expr_check_func() +returns void as $$ +declare var int; +begin + var = 1 + delete from pg_class where false; +end; +$$ language plpgsql; +ERROR: syntax of expression is not strict +LINE 5: var = 1 + ^ +reset plpgsql.extra_errors; -- runtime extra checks set plpgsql.extra_warnings to 'too_many_rows'; do $$ diff --git a/src/test/regress/sql/plpgsql.sql b/src/test/regress/sql/plpgsql.sql index d413d995d17e..dd0d908d4220 100644 --- a/src/test/regress/sql/plpgsql.sql +++ b/src/test/regress/sql/plpgsql.sql @@ -2618,6 +2618,20 @@ declare f1 int; begin return 1; end $$ language plpgsql; select shadowtest(1); +-- test of strict expression check +set plpgsql.extra_errors to 'strict_expr_check'; + +create or replace function strict_expr_check_func() +returns void as $$ +declare var int; +begin + var = 1 + delete from pg_class where false; +end; +$$ language plpgsql; + +reset plpgsql.extra_errors; + -- runtime extra checks set plpgsql.extra_warnings to 'too_many_rows';