diff options
Diffstat (limited to 'contrib/postgres_fdw/connection.c')
-rw-r--r-- | contrib/postgres_fdw/connection.c | 69 |
1 files changed, 63 insertions, 6 deletions
diff --git a/contrib/postgres_fdw/connection.c b/contrib/postgres_fdw/connection.c index 202e7e583b3..0274d6c253d 100644 --- a/contrib/postgres_fdw/connection.c +++ b/contrib/postgres_fdw/connection.c @@ -19,6 +19,7 @@ #include "access/xact.h" #include "catalog/pg_user_mapping.h" #include "commands/defrem.h" +#include "common/base64.h" #include "funcapi.h" #include "libpq/libpq-be.h" #include "libpq/libpq-be-fe-helpers.h" @@ -177,6 +178,7 @@ static void pgfdw_finish_abort_cleanup(List *pending_entries, static void pgfdw_security_check(const char **keywords, const char **values, UserMapping *user, PGconn *conn); static bool UserMappingPasswordRequired(UserMapping *user); +static bool UseScramPassthrough(ForeignServer *server, UserMapping *user); static bool disconnect_cached_connections(Oid serverid); static void postgres_fdw_get_connections_internal(FunctionCallInfo fcinfo, enum pgfdwVersion api_version); @@ -485,7 +487,7 @@ connect_pg_server(ForeignServer *server, UserMapping *user) * for application_name, fallback_application_name, client_encoding, * end marker. */ - n = list_length(server->options) + list_length(user->options) + 4; + n = list_length(server->options) + list_length(user->options) + 4 + 2; keywords = (const char **) palloc(n * sizeof(char *)); values = (const char **) palloc(n * sizeof(char *)); @@ -554,10 +556,37 @@ connect_pg_server(ForeignServer *server, UserMapping *user) values[n] = GetDatabaseEncodingName(); n++; + if (MyProcPort->has_scram_keys && UseScramPassthrough(server, user)) + { + int len; + + keywords[n] = "scram_client_key"; + len = pg_b64_enc_len(sizeof(MyProcPort->scram_ClientKey)); + /* don't forget the zero-terminator */ + values[n] = palloc0(len + 1); + pg_b64_encode((const char *) MyProcPort->scram_ClientKey, + sizeof(MyProcPort->scram_ClientKey), + (char *) values[n], len); + n++; + + keywords[n] = "scram_server_key"; + len = pg_b64_enc_len(sizeof(MyProcPort->scram_ServerKey)); + /* don't forget the zero-terminator */ + values[n] = palloc0(len + 1); + pg_b64_encode((const char *) MyProcPort->scram_ServerKey, + sizeof(MyProcPort->scram_ServerKey), + (char *) values[n], len); + n++; + } + keywords[n] = values[n] = NULL; - /* verify the set of connection parameters */ - check_conn_params(keywords, values, user); + /* + * Verify the set of connection parameters only if scram pass-through + * is not being used because the password is not necessary. + */ + if (!(MyProcPort->has_scram_keys && UseScramPassthrough(server, user))) + check_conn_params(keywords, values, user); /* first time, allocate or get the custom wait event */ if (pgfdw_we_connect == 0) @@ -575,8 +604,12 @@ connect_pg_server(ForeignServer *server, UserMapping *user) server->servername), errdetail_internal("%s", pchomp(PQerrorMessage(conn))))); - /* Perform post-connection security checks */ - pgfdw_security_check(keywords, values, user, conn); + /* + * Perform post-connection security checks only if scram pass-through + * is not being used because the password is not necessary. + */ + if (!(MyProcPort->has_scram_keys && UseScramPassthrough(server, user))) + pgfdw_security_check(keywords, values, user, conn); /* Prepare new session for use */ configure_remote_session(conn); @@ -629,6 +662,30 @@ UserMappingPasswordRequired(UserMapping *user) return true; } +static bool +UseScramPassthrough(ForeignServer *server, UserMapping *user) +{ + ListCell *cell; + + foreach(cell, server->options) + { + DefElem *def = (DefElem *) lfirst(cell); + + if (strcmp(def->defname, "use_scram_passthrough") == 0) + return defGetBoolean(def); + } + + foreach(cell, user->options) + { + DefElem *def = (DefElem *) lfirst(cell); + + if (strcmp(def->defname, "use_scram_passthrough") == 0) + return defGetBoolean(def); + } + + return false; +} + /* * For non-superusers, insist that the connstr specify a password or that the * user provided their own GSSAPI delegated credentials. This @@ -666,7 +723,7 @@ check_conn_params(const char **keywords, const char **values, UserMapping *user) ereport(ERROR, (errcode(ERRCODE_S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED), errmsg("password or GSSAPI delegated credentials required"), - errdetail("Non-superusers must delegate GSSAPI credentials or provide a password in the user mapping."))); + errdetail("Non-superusers must delegate GSSAPI credentials, provide a password, or enable SCRAM pass-through in user mapping."))); } /* |